The following code:
import numpyro as npy
from jax import lax, random
import jax.numpy as jnp
import numpyro.distributions.constraints as constraints
import numpyro.distributions as dist
pos = constraints.interval(0.0001, 9.223372036854776e18)
def guide():
a = npy.param("a", 1.0, constraint=pos)
b = npy.param("b", 2.0, constraint=pos)
scale = npy.sample("scale", uniform_maker(a, b))
c = npy.param("c", 0.0)
return dist.Normal(c, scale)
data = jnp.array([0.0, 1.0, 1.0, 2.0, 2.0, 3.0])
def model():
with npy.plate("observations", len(data)):
# print("observing")
npy.sample(f"result", guide(), obs=data)
def uniform_maker(a, b):
low = jnp.minimum(a, b)
high = jnp.maximum(a, b)
return dist.Normal(low, high)
optimizer = npy.optim.Adam(step_size=0.1)
svi = npy.infer.SVI(
model=model,
guide=guide,
optim=optimizer,
loss=npy.infer.Trace_ELBO(),
)
init_state = svi.init(random.PRNGKey(0))
num_steps = 100
svi.update(init_state)
state, losses = lax.scan(
lambda state, i: svi.update(state), init_state, jnp.arange(2000)
)
results = svi.get_params(state)
for k, v in results.items():
if jnp.isnan(v):
print(k, "is nan", v)
print(losses[-1])
print(losses)
Prints out the following result:
nan
[27.588427 44.92485 35.44729 ... nan 66.97971 nan]
In other words, the loss becomes NaN.
Why is this, and how can I avoid it?
Thanks.