Hi,
I’m trying to estimate the decay rate in an exponential model. With a small learning rate (say 1e-5), my inference always fails because the loss becomes inf. With a larger learning rate (1e-1), the inference fails in the same way most, but not all, of the time. I’m a little confused as to why this happens, and I would appreciate any help. Thanks in advance (and apologies in the likely event I’m just doing something silly)!
I’m generating data from a decaying exponential model and adding noise as shown below.
import torch
import numpy as np
import matplotlib.pyplot as plt
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
T1t = 1.3  # true decay constant
times = torch.linspace(0., 8., 20)  # 20 time points between 0 and 8
f = 1 - torch.exp(- times / T1t)  # noise-free signal
noise_std = 0.1
y = torch.normal(f, noise_std)  # add Gaussian observation noise
plt.scatter(times.numpy(), y.numpy())
The above code generates a slightly noisy curve that tends towards 1, as expected.
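Just to be explicit about what I’m seeing, this is how I overlay the noise-free curve on the noisy samples (purely a sanity check on the simulated data, not part of the inference):
plt.plot(times.numpy(), f.numpy(), color="C1", label="noise-free curve")
plt.scatter(times.numpy(), y.numpy(), label="noisy samples")
plt.legend()
plt.show()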
I now want to learn the value of the decay rate, T1t. I define my model below. I have good reason to believe the decay parameter is around 1.3, so I set that as the mean of the prior and set its std to 5.
def model(data, times):
    # prior over the decay rate T1t
    T1t_mean = pyro.param("T1t_mean", torch.tensor(1.3))
    T1t_std = pyro.param("T1t_std", torch.tensor(5.), constraint=torch.distributions.constraints.positive)
    T1t = pyro.sample("T1t", dist.Normal(T1t_mean, T1t_std))
    # expected signal for this draw of T1t
    f = 1 - torch.exp(- times / T1t)
    # observation noise
    noise_std = pyro.param("noise_std", torch.tensor(1.), constraint=torch.distributions.constraints.positive)
    with pyro.plate("observe_data"):
        pyro.sample("obs", dist.Normal(f, noise_std), obs=data)
I set up my guide below:
def guide(data, times):
    T1t_mean = pyro.param("T1t_mean_q", torch.tensor(1.3))
    T1t_std = pyro.param("T1t_std_q", torch.tensor(5.), constraint=torch.distributions.constraints.positive)
    T1t = pyro.sample("T1t", dist.Normal(T1t_mean, T1t_std))
    f = 1 - torch.exp(- times / T1t)
    noise_std = pyro.param("noise_std_q", torch.tensor(1.), constraint=torch.distributions.constraints.positive)
Finally, I’m running inference as shown below. The first two lines (the calls to enable_validation and clear_param_store) were copied straight from one of the tutorials.
# enable validation (e.g. validate parameters of distributions)
pyro.enable_validation(True)
# clear the param store in case we're in a REPL
pyro.clear_param_store()
# setup the optimizer
adam_params = {"lr": 0.005, "betas": (0.90, 0.999)}
optimizer = pyro.optim.Adam(adam_params)
# setup the inference algorithm
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())
# do gradient steps
n_steps = 5000
loss = []
for step in range(n_steps):
    l = svi.step(y, times)
    loss.append(l)
    if step % 100 == 0:
        print(f"Step = {step} Loss = {l:.3f}")