I wanted to try coding a state-space model in NumPyro using `lax.scan`. I’m running into issues that make me suspect this is not supported, or perhaps I’m just getting something else wrong! Here’s my model:
```python
import jax
import jax.numpy as np  # assumption: `np` refers to jax.numpy in this model
import numpyro
import numpyro.distributions as dist

def target(T=10, q=1, r=1, phi=0., beta=0.):
    def transition(state, i):
        x0, mu0 = state
        x1 = numpyro.sample(f'x_{i}', dist.Normal(phi*x0, q))
        mu1 = beta * mu0 + x1
        y1 = numpyro.sample(f'y_{i}', dist.Normal(mu1, r))
        return (x1, mu1), (x1, y1)

    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))
    _, xy = jax.lax.scan(transition, (x0, mu0), np.arange(1, T))
    x, y = xy
    return np.append(x0, x), np.append(y0, y)
```
This returns:

```
x [-1.1470195 -0.3285517 -0.3285517 -0.3285517 -0.3285517 -0.3285517
 -0.3285517 -0.3285517 -0.3285517 -0.3285517]
y [-2.2391834 0.32762653 0.32762653 0.32762653 0.32762653 0.32762653
 0.32762653 0.32762653 0.32762653 0.32762653]
```
It appears the sample statements within `transition` only generate one random value, which is repeated in each iteration. When I try to use this model within `Predictive`, I get an error:
```python
prior = Predictive(target, posterior_samples={}, num_samples=10)
prior_samples = prior(PRNGKey(2), T=10)
```
```
UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state.
Details: Can't lift level Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)> to JaxprTrace(level=0/0).
```
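My guess at what is going on, which I have not confirmed against the NumPyro internals: `lax.scan` traces `transition` only once, so each `numpyro.sample` call inside it runs once at trace time and the random key it consumes ends up baked into the loop body as a constant. The pure-JAX sketch that follows (no NumPyro involved; `body_fixed_key` and `body_threaded_key` are names I made up) reproduces that symptom: a key closed over from outside the scan body gives the same draw on every iteration, while a key threaded through the carry and split at each step gives fresh draws.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

def body_fixed_key(carry, i):
    # `key` is captured from the enclosing scope, so the body (traced once)
    # uses the same key on every iteration and returns the same number.
    z = jax.random.normal(key)
    return carry, z

def body_threaded_key(carry_key, i):
    # The key lives in the carry and is split each step, so every
    # iteration draws with a fresh subkey.
    carry_key, subkey = jax.random.split(carry_key)
    z = jax.random.normal(subkey)
    return carry_key, z

_, same = jax.lax.scan(body_fixed_key, None, jnp.arange(5))
_, fresh = jax.lax.scan(body_threaded_key, key, jnp.arange(5))
print(same)   # one value repeated five times
print(fresh)  # five different values
```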
I don’t need to get this particular model running, and I understand that I could reparameterize it to generate all of the random variables outside the loop. I’m wondering about more general state-space models whose transitions can’t be reparameterized so easily: can one put `sample` statements inside a loop that is executed by `lax.scan`?
Thanks!