Experimenting with scan
for state space models I stumbled upon something weird. It looks like inference fails when I scan over the observed variable, and succeeds if I scan over an index variable and use it to index the observed variable.
Please see the example of a linear regression model implemented using scan
below. I would like to scan over both xs
and ys
to avoid indexing them in f
, but this does not work. So if someone could could point out why my approach does not work and how to do it I would be very grateful. Thanks!
I’m using numpyro version 0.9.2.
import jax
import jax.numpy as np
import numpyro as npo
from numpyro.infer import MCMC, NUTS
from numpyro import distributions as D
from numpyro.contrib.control_flow import scan
rng_key = jax.random.PRNGKey(1234)
def sample(model, *args):
mcmc = MCMC(
NUTS(model),
num_warmup=1000,
num_samples=1000,
num_chains=1,
progress_bar=True,
)
mcmc.run(*args)
return mcmc
def scan_linear_works(ys, num_forecast_steps=0):
a = npo.sample("a", D.Normal(0, 1))
b = npo.sample("b", D.Normal(0, 0.1))
sigma = npo.sample("sigma", D.HalfNormal(0.1))
xs = np.arange(ys.shape[0])
def f(carry, x):
i = carry
with npo.plate("time", ys.shape[0], dim=-1):
mu = npo.deterministic("mu", a + b * x)
npo.sample("y", D.Normal(mu, sigma), obs=ys[i])
return i+1, None
scan(f, np.array([0]), xs)
def scan_linear_broken(ys, num_forecast_steps=0):
a = npo.sample("a", D.Normal(0, 1))
b = npo.sample("b", D.Normal(0, 0.1))
sigma = npo.sample("sigma", D.HalfNormal(0.1))
xs = np.arange(ys.shape[0])
def f(state, y):
i = state
with npo.plate("time", ys.shape[0], dim=-1):
mu = npo.deterministic("mu", a + b * xs[i])
npo.sample("y", D.Normal(mu, sigma), obs=y)
return i+1, None
scan(f, np.array([0]), ys)
true_a = -0.75
true_b = 0.5
xx = np.arange(20)
yy = true_a + true_b*xx + jax.random.normal(rng_key, xx.shape)
# recovers true parameters
sample(scan_linear_works, rng_key, yy).print_summary()
# fails to recover true parameters
sample(scan_linear_broken, rng_key, yy).print_summary()