Experimenting with scan for state space models I stumbled upon something weird. It looks like inference fails when I scan over the observed variable, and succeeds if I scan over an index variable and use it to index the observed variable.
Please see the example of a linear regression model implemented using scan below. I would like to scan over both xs and ys to avoid indexing them in f, but this does not work. So if someone could could point out why my approach does not work and how to do it I would be very grateful. Thanks!
I’m using numpyro version 0.9.2.
import jax
import jax.numpy as np
import numpyro as npo
from numpyro.infer import MCMC, NUTS
from numpyro import distributions as D
from numpyro.contrib.control_flow import scan
rng_key = jax.random.PRNGKey(1234)
def sample(model, *args):
mcmc = MCMC(
NUTS(model),
num_warmup=1000,
num_samples=1000,
num_chains=1,
progress_bar=True,
)
mcmc.run(*args)
return mcmc
def scan_linear_works(ys, num_forecast_steps=0):
a = npo.sample("a", D.Normal(0, 1))
b = npo.sample("b", D.Normal(0, 0.1))
sigma = npo.sample("sigma", D.HalfNormal(0.1))
xs = np.arange(ys.shape[0])
def f(carry, x):
i = carry
with npo.plate("time", ys.shape[0], dim=-1):
mu = npo.deterministic("mu", a + b * x)
npo.sample("y", D.Normal(mu, sigma), obs=ys[i])
return i+1, None
scan(f, np.array([0]), xs)
def scan_linear_broken(ys, num_forecast_steps=0):
a = npo.sample("a", D.Normal(0, 1))
b = npo.sample("b", D.Normal(0, 0.1))
sigma = npo.sample("sigma", D.HalfNormal(0.1))
xs = np.arange(ys.shape[0])
def f(state, y):
i = state
with npo.plate("time", ys.shape[0], dim=-1):
mu = npo.deterministic("mu", a + b * xs[i])
npo.sample("y", D.Normal(mu, sigma), obs=y)
return i+1, None
scan(f, np.array([0]), ys)
true_a = -0.75
true_b = 0.5
xx = np.arange(20)
yy = true_a + true_b*xx + jax.random.normal(rng_key, xx.shape)
# recovers true parameters
sample(scan_linear_works, rng_key, yy).print_summary()
# fails to recover true parameters
sample(scan_linear_broken, rng_key, yy).print_summary()