Hi,
I’m new to NumPyro and am hoping to use it for modelling ODE problems. As a learning exercise, I’ve tried inferring the parameters of a simple nonlinear model (exponential decay). The data and model are below:
```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO

t = jnp.linspace(0, 2, 100)
data = 20 * jnp.exp(-5 * t) + 2 * random.normal(random.PRNGKey(0), (100,))

def model(data):
    t = jnp.linspace(0.0, 2.0, 100)
    a = numpyro.sample("a", dist.Normal(0.0, 50.0))
    b = numpyro.sample("b", dist.Normal(0.0, 50.0))
    sigma = numpyro.sample("sigma", dist.Normal(0.0, 50.0))
    mean = a * jnp.exp(-b * t)
    with numpyro.plate("data", len(data)):
        numpyro.sample("obs", dist.Normal(mean, sigma), obs=data)
```
This seems to work fine with an autoguide or the NUTS sampler, but I can’t work out why my hand-written guide does not work. Here it is:
```python
def guide(data):
    t = jnp.linspace(0.0, 2.0, 100)
    a_loc = numpyro.param("a_loc", 20.0)
    a_scale = numpyro.param("a_scale", 10.0, constraint=constraints.positive)
    b_loc = numpyro.param("b_loc", 5.0)
    b_scale = numpyro.param("b_scale", 10.0, constraint=constraints.positive)
    sigma_loc = numpyro.param("sigma_loc", 2.0)
    sigma_scale = numpyro.param("sigma_scale", 10.0, constraint=constraints.positive)
    a = numpyro.sample("a", dist.Normal(a_loc, a_scale))
    b = numpyro.sample("b", dist.Normal(b_loc, b_scale))
    sigma = numpyro.sample("sigma", dist.Normal(sigma_loc, sigma_scale))
    mean = a * jnp.exp(-b * t)
    with numpyro.plate("data", len(data)):
        numpyro.sample("obs", dist.Normal(mean, sigma))

optimizer = numpyro.optim.Adam(step_size=0.05)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10000, data)
```
The result of this is nan.
My understanding is that the initial values passed to `numpyro.param` in the guide act as an initial guess for the optimisation. Is this correct?
I suspect I’ve made a trivial mistake stemming from a misunderstanding. Any advice on what is going wrong and how to fix it would be much appreciated.
Thanks,
Pavan