Infinite loss when fitting simple model


I’m trying to estimate the decay time constant in an exponential model. With a small learning rate (say 1e-5), my inference always fails because the loss becomes inf. With a larger learning rate (1e-1), it fails in the same way most, but not all, of the time. I’m a little confused as to why this happens and would appreciate any help. Thanks in advance (and apologies in the likely event I’m just doing something silly)!

I’m generating data from an exponential recovery model, f(t) = 1 - exp(-t / T1t), and adding Gaussian noise as shown below.

import torch
import numpy as np
import matplotlib.pyplot as plt

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

T1t = 1.3
times = torch.linspace(0., 8., 20)
f = 1 - torch.exp(- times / T1t)

noise_std = 0.1
y = torch.normal(f, noise_std)
plt.scatter(times.numpy(), y.numpy())

The code above generates a slightly noisy curve that tends towards 1, as expected.

I now want to learn the value of the time constant T1t. My model is defined below. I have good reason to suspect that T1t is about 1.3, so I set this as the mean of its prior, with a std of 5.

def model(data, times):
    T1t_mean = pyro.param("T1t_mean", torch.tensor(1.3))
    T1t_std = pyro.param("T1t_std", torch.tensor(5.), constraint=torch.distributions.constraints.positive)
    T1t = pyro.sample("T1t", dist.Normal(T1t_mean, T1t_std))

    f = 1 - torch.exp(- times / T1t)

    noise_std = pyro.param("noise_std", torch.tensor(1.), constraint=torch.distributions.constraints.positive)
    with pyro.plate("observe_data"):
        pyro.sample("obs", pyro.distributions.Normal(f, noise_std), obs=data)

I set up my guide below:

def guide(data, times):
    T1t_mean = pyro.param("T1t_mean_q", torch.tensor(1.3))
    T1t_std = pyro.param("T1t_std_q", torch.tensor(5.), constraint=torch.distributions.constraints.positive)
    T1t = pyro.sample("T1t", dist.Normal(T1t_mean, T1t_std))

    f = 1 - torch.exp(- times / T1t)

    noise_std = pyro.param("noise_std_q", torch.tensor(1.), constraint=torch.distributions.constraints.positive)

Finally, I run inference as shown below. The first two calls (enable_validation and clear_param_store) were just copied from one of the tutorials.

# enable validation (e.g. validate parameters of distributions)
pyro.enable_validation(True)

# clear the param store in case we're in a REPL
pyro.clear_param_store()

# setup the optimizer
adam_params = {"lr": 0.005, "betas": (0.90, 0.999)}
optimizer = pyro.optim.Adam(adam_params)

# setup the inference algorithm
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())

# do gradient steps
n_steps = 5000
loss = []
for step in range(n_steps):
    l = svi.step(y, times)
    loss.append(l)  # keep the loss history
    if step % 100 == 0:
        print(f"Step = {step}    Loss = {l:.3f}")

I’d generally recommend starting with a small variance for all distributions in the guide (here, e.g., initializing T1t_std_q to torch.tensor(0.01)).

Also, does it really make sense to put a Normal prior on T1t, which appears in a denominator?
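To illustrate why the denominator matters, here is a small torch-only sketch (not part of the original thread). It draws T1t from the initial guide Normal(1.3, std) and counts how often the observation log-likelihood is non-finite, assuming the model's initial noise_std of 1.0. With std = 5, some draws land at or just below zero, where `exp(-times / T1t)` overflows float32 and the log-likelihood becomes -inf; with std = 0.01 the draws stay near 1.3.

```python
import torch
import torch.distributions as tdist

torch.manual_seed(0)

# Same synthetic data as in the question.
times = torch.linspace(0., 8., 20)
y = torch.normal(1 - torch.exp(-times / 1.3), 0.1)

def frac_nonfinite(guide_std, n_samples=100_000):
    """Fraction of draws T1t ~ Normal(1.3, guide_std) whose data log-likelihood is not finite."""
    t1t = tdist.Normal(1.3, guide_std).sample((n_samples,))           # shape (n,)
    # Mean curve for every draw at every time point.
    f = 1 - torch.exp(-times / t1t.unsqueeze(-1))                      # shape (n, 20)
    # Log-likelihood under the model's initial noise_std = 1.0;
    # validation is off so -inf locations do not raise.
    log_lik = tdist.Normal(f, 1.0, validate_args=False).log_prob(y).sum(-1)
    return (~torch.isfinite(log_lik)).float().mean().item()

print(frac_nonfinite(5.0))   # small but non-zero
print(frac_nonfinite(0.01))  # 0.0
```

Since each SVI step draws a fresh T1t, even a small per-draw failure rate makes an infinite loss at some point during a long run very likely, at least while the guide is still close to its initialization.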


Hi Martin,

Thanks a lot for this; it has done the trick.

For some intuition as to why: could it be that a large variance in the guide leads to many unlikely values of the latent variable, in this case T1t? Then, when computing the loss, we evaluate the ratio p(z)/q(z), so when q(z) is very small this causes an issue?
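For reference, Trace_ELBO with its default single particle estimates the loss (the negative ELBO) from one draw z of the guide. Writing z = T1t, σ = noise_std and t_i for the 20 entries of `times`, the per-step estimate is:

```latex
\begin{aligned}
\widehat{\mathcal{L}} &= -\big[\log p(y \mid z) + \log p(z) - \log q(z)\big], \qquad z \sim q(z), \\
\log p(y \mid z) &= \sum_{i=1}^{20} \log \mathcal{N}\big(y_i \mid 1 - e^{-t_i/z},\ \sigma^2\big).
\end{aligned}
```

The log p(z) - log q(z) part is the log of the ratio mentioned above, but the likelihood term log p(y | z) can also become -∞ on its own when a draw of z is close to zero or negative, because e^{-t_i/z} then overflows.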

And perhaps training occasionally completes without error at a larger learning rate because the guide quickly concentrates its probability on a narrower range of values, making it less likely that we sample a value improbable enough to cause the issue above?

As for the Normal prior on T1t: probably not, so I will change it.

Well, the exact details will vary depending on the model/guide, but basically you will tend to get very stochastic (i.e. high-variance) gradient estimates, and that can lead you into strange/bad parts of parameter space.
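A small torch-only sketch of that point (an illustration, not from the thread). It draws reparameterized samples T1t = mu + s * eps around mu = 1.3 and compares the spread of the per-sample gradients of the log-likelihood with respect to mu for a wider (s = 0.2) and a narrower (s = 0.02) guide. It looks only at the likelihood term of the ELBO gradient, assumes the true noise level of 0.1, and keeps both scales far enough from zero that every gradient stays finite.

```python
import torch

torch.manual_seed(0)

# Same synthetic data as in the question.
times = torch.linspace(0., 8., 20)
y = torch.normal(1 - torch.exp(-times / 1.3), 0.1)

def grad_std(guide_std, n_samples=10_000, noise_std=0.1):
    """Std. dev. of single-sample gradients of log p(y | T1t) w.r.t. mu,
    where T1t = mu + guide_std * eps and eps ~ N(0, 1)."""
    mu = torch.tensor(1.3)
    eps = torch.randn(n_samples)
    # Reparameterized draws; track gradients with respect to each draw.
    t1t = (mu + guide_std * eps).requires_grad_()
    f = 1 - torch.exp(-times / t1t.unsqueeze(-1))                           # (n, 20)
    log_lik = torch.distributions.Normal(f, noise_std).log_prob(y).sum(-1)  # (n,)
    # log_lik[i] depends only on t1t[i], and d t1t[i] / d mu = 1,
    # so these are the per-sample gradient estimates with respect to mu.
    (grads,) = torch.autograd.grad(log_lik.sum(), t1t)
    return grads.std().item()

print(grad_std(0.2), grad_std(0.02))
```

The wider guide produces per-sample gradients that vary much more from draw to draw; at the original scale of 5, some draws would not even give finite gradients.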