SVI returns NaN loss

The following code:

import numpyro as npy
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from jax import lax, random
import jax.numpy as jnp

pos = constraints.interval(0.0001, 9.223372036854776e18)


def guide():
    a = npy.param("a", 1.0, constraint=pos)
    b = npy.param("b", 2.0, constraint=pos)
    scale = npy.sample("scale", uniform_maker(a, b))
    c = npy.param("c", 0.0)
    return dist.Normal(c, scale)


data = jnp.array([0.0, 1.0, 1.0, 2.0, 2.0, 3.0])


def model():
    with npy.plate("observations", len(data)):
        npy.sample("result", guide(), obs=data)


def uniform_maker(a, b):
    low = jnp.minimum(a, b)
    high = jnp.maximum(a, b)
    return dist.Normal(low, high)


optimizer = npy.optim.Adam(step_size=0.1)
svi = npy.infer.SVI(
    model=model,
    guide=guide,
    optim=optimizer,
    loss=npy.infer.Trace_ELBO(),
)
init_state = svi.init(random.PRNGKey(0))

state, losses = lax.scan(
    lambda state, i: svi.update(state), init_state, jnp.arange(2000)
)
results = svi.get_params(state)
for k, v in results.items():
    if jnp.isnan(v):
        print(k, "is nan", v)
print(losses[-1])
print(losses)

Prints out the following result:

nan
[27.588427 44.92485  35.44729  ...       nan 66.97971        nan]

In other words, the loss becomes NaN.

Why is this, and how can I avoid it?

Thanks.

The first thing to try is to reduce your learning rate. Try 0.001, and if that doesn’t NaN, try 0.01.

Thanks for the suggestion. I tried various learning rates (0.1, 0.01, 0.000001, 2.225E-307, 0) and all of them produced NaN.

Of note is that it doesn’t always produce NaN: rather, each evaluation may or may not produce NaN.

have you tried using a less extreme constraint? e.g. 1000 instead of 10^18?

an interval constraint will be implemented with something like a rescaled sigmoid. so somewhere in the code you’ll effectively have something like 1.0e18 * sigmoid(...). you can imagine why that might be a bad idea
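
(concretely, you can inspect the bijection SVI uses to push unconstrained optimizer values into a constraint’s support — biject_to is the NumPyro registry for these transforms:)

import jax.numpy as jnp
import numpyro.distributions.constraints as constraints
from numpyro.distributions.transforms import biject_to

# an interval constraint is handled by a sigmoid-based bijection,
# roughly lower + (upper - lower) * sigmoid(x)
transform = biject_to(constraints.interval(0.0001, 9.223372036854776e18))
for x in [-1.0, 0.0, 1.0]:
    print(x, transform(jnp.array(x)))
# with an upper bound near 1e19, one unit step in x moves the constrained
# value by around 1e18, which easily blows up the ELBO and its gradients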

Good tip. Unfortunately, doesn’t fix the problem :frowning:.

Here’s the code I’m working with right now:

import numpyro as npy
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from jax import lax, random
import jax.numpy as jnp

pos = constraints.interval(0.0001, 1000)


def guide():
    a = npy.param("a", 1.0, constraint=pos)
    b = npy.param("b", 2.0, constraint=pos)
    scale = npy.sample("scale", uniform_maker(a, b))
    c = npy.param("c", 0.0)
    return dist.Normal(c, scale)


data = jnp.array([0.0, 1.0, 1.0, 2.0, 2.0, 3.0])


def model():
    with npy.plate("observations", len(data)):
        npy.sample("result", guide(), obs=data)


def uniform_maker(a, b):
    low = jnp.minimum(a, b)
    high = jnp.maximum(a, b)
    return dist.Normal(low, high)


optimizer = npy.optim.Adam(step_size=0.001)
svi = npy.infer.SVI(
    model=model,
    guide=guide,
    optim=optimizer,
    loss=npy.infer.Trace_ELBO(),
)
init_state = svi.init(random.PRNGKey(0))

state, losses = lax.scan(
    lambda state, i: svi.update(state), init_state, jnp.arange(2000)
)
results = svi.get_params(state)
for k, v in results.items():
    if jnp.isnan(v):
        print(k, "is nan", v)
count = 0
for loss in losses:
    print(loss)
    if jnp.isnan(loss):
        count += 1

print(count, "nan losses")

Which outputs:

673 nan losses

i’m not sure if this modeling setup makes much sense.
why is the guide inside of the model? what exactly are you trying to accomplish?

I will confess I don’t know much about probabilistic programming. From my understanding (from what I saw in the tutorial), the model is just the guide + the observations.

I realize this is probably not the case now, but I am not sure exactly what the “model” actually is.

I am trying to find the parameters a, b, and c of the distribution:
scale ~ Uniform(a, b)
data ~ Normal(c, scale)
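
(In NumPyro terms, a direct translation of that generative process might look something like the following, with placeholder values standing in for the a, b, and c I want to learn:)

import numpyro as npy
import numpyro.distributions as dist


def intended_model(data):
    a, b, c = 0.5, 3.0, 0.0  # placeholders for the unknown parameters
    scale = npy.sample("scale", dist.Uniform(a, b))
    with npy.plate("observations", len(data)):
        npy.sample("result", dist.Normal(c, scale), obs=data)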

it might be helpful to look at this resource to help you figure out how your model should be formulated. once you have a better idea of what your goal is, it will be easier for us to help you

Thank you for the book recommendation. I went through and read the first two chapters.

I understand that the model is the prior, while the guide is what is actually being sampled.

I see the use for having them be different. But I don’t understand why they cannot be the same. Indeed, if you look at the sample here, it appears to me that the model (prior) is just the guide with two differences:

  1. Observations are added
  2. Params are transformed into constants

Is this view wrong?

In case it was the weird recursion or the params-in-model causing my issues, I went ahead and made a model more in the style of what I found in the documentation:

def guide():
    a = npy.param("a", 1.0, constraint=pos)
    b = npy.param("b", 2.0, constraint=pos)
    scale = npy.sample("scale", uniform_maker(a, b))
    c = npy.param("c", 0.0)
    return dist.Normal(c, scale)


data = jnp.array([0.0, 1.0, 1.0, 2.0, 2.0, 3.0])


def model():
    with npy.plate("observations", len(data)):
        a = 1.0
        b = 2.0
        scale = npy.sample("scale", uniform_maker(a, b))
        c = 0.0
        norm = dist.Normal(c, scale)
        npy.sample(f"result", norm, obs=data)


def uniform_maker(a, b):
    low = jnp.minimum(a, b)
    high = jnp.maximum(a, b)
    return dist.Normal(low, high)

This got me the same output.

Is there some key fact about the guide I’m misunderstanding? Are there certain distributions that cannot be guides?

Hi @GUIpsp,
Maybe it would be helpful to use different names than ‘model’ and ‘guide’. The model should represent prior information; that is, it should encode everything you know about your latent variables even before seeing any data. The guide in Pyro represents an approximate posterior, i.e. it is trained by SVI to encode a combination of all the information in the prior (the model) plus all the information in the data. Schematically,

posterior = prior + data

In a strict Bayesian paradigm, all parameters should be in the guide because only the posterior is learnable from data; the model should have no parameters. In practice we sometimes add parameters to the model, but that is really slang for “we have zero prior information and seek maximum-likelihood point estimates of latent variables”. That is, when the model has parameters, we are being non-Bayesian about that part of the model.
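
For concreteness, here is a minimal sketch of that pattern (a hypothetical model whose only learnable quantity is a maximum-likelihood point estimate):

import torch
import pyro
import pyro.distributions as dist


def mle_model(data):
    # loc has no prior; pyro.param makes it a learnable point estimate,
    # so SVI fits it by maximum likelihood rather than Bayesian updating
    loc = pyro.param("loc", torch.tensor(0.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)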

It might also help to look at some more complex (model, guide) pairs. In simple models with a single latent variable, an optimal guide can often look exactly like the model, but (as you say) with the fixed parameters in the model replaced by learnable parameters in the guide. However, in more complex models there can be dependencies between latent variables in the guide even when those variables are independent in the model; thus in general guides represent more complex distributions than models. Again, that is because, roughly, “posterior = prior + data” or “guide = model + data”.

Here’s a simple example model:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints


def model(data):
    x = pyro.sample("x", dist.Normal(0, 1))
    y = pyro.sample("y", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(x + y, 1), obs=data)

Notice that x and y are independent in the model (the prior) but are negatively correlated in the posterior: the data pins down the sum x + y, so if x is larger then y must be smaller. Here is a guide that turns out to be optimal (once its parameters have been fit):

def guide(data):
    loc = pyro.param("loc", torch.zeros(2))
    scale_tril = pyro.param("scale_tril", torch.eye(2),
                            constraint=constraints.lower_cholesky)
    xy = pyro.sample("xy", dist.MultivariateNormal(loc, scale_tril=scale_tril),
                     infer={"is_auxiliary": True})
    x, y = xy.unbind(-1)
    pyro.sample("x", dist.Delta(x))
    pyro.sample("y", dist.Delta(y))

Notice that this guide comes from a richer family than the model.
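
(To actually fit such a pair you’d use the standard SVI loop; a sketch, with hypothetical data:)

import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
data = torch.randn(100) + 1.0  # hypothetical observations
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(1000):
    loss = svi.step(data)  # one gradient step on the ELBO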
Cheers.

Thank you so much for your in-depth response.

Two further questions:

  • Can an incorrect guide cause the behavior shown in the OP?
  • Given that I know the model (except for its parameters), how can I produce a good guide for it? Is this one of those problems where you just have to know that a certain distribution X combined with a certain distribution Y calls for a certain distribution Z in the guide, for every combination of X and Y?

Thanks

Another question:

Just by adding jnp.abs() around the scale parameter to Normal, in line 17:

    return dist.Normal(c, jnp.abs(scale))

the NaNs disappear. This makes no sense, at least to me, since that parameter should always be positive. Why does this happen?

You have a bug in uniform_maker: you should be using Uniform rather than Normal. As written, scale is drawn from Normal(low, high), which can be negative, and a Normal with a negative scale yields NaN log-probabilities; jnp.abs() just masks that.
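
For reference, the corrected helper would be:

import jax.numpy as jnp
import numpyro.distributions as dist


def uniform_maker(a, b):
    low = jnp.minimum(a, b)
    high = jnp.maximum(a, b)
    return dist.Uniform(low, high)  # Uniform, as intended, not Normal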

@fritzo You are correct. Turns out I made a mistake when writing the repro. That’s embarrassing.

Turns out the issue was a mixture of many different things in the program the repro was extracted from. The learning rate was too high, and I had to reduce my constraint bounds down to ~100, as @martinjankowiak suggested, in order to get it to work, but it did work in the end.

Thank you all.