# State Space Modeling - Using Jax.Scan

Hello,

I am trying to model a linear single-degree-of-freedom oscillator in NumPyro and run MCMC over the global parameters `xi` and `wn` as well as the local latent states `x` and `v` (stacked in the state vector `z`; I do not want to marginalize them out). I previously used the following Pyro code, which was inefficient because of the Python loop over time steps. Here is the code:

```python
import torch
import pyro
import pyro.distributions as dist

def model(Qnoise, Rnoise, delta, respAcc, inpAcc):
    T = respAcc.shape[0]
    # Parameter distributions
    xi = pyro.sample('xi', dist.Uniform(0.0, 1.0))  # two-sided constraint
    wn = pyro.sample('wn', dist.LogNormal(-1.0, 0.4))
    # Observation vector: a_t = -wn^2 * x_t - 2*xi*wn * v_t
    # (torch.stack keeps the autograd graph; torch.tensor would detach xi and wn)
    C = torch.stack([-torch.square(wn), -2 * xi * wn])

    # Prior on the initial state z_0 = (x_0, v_0) and the first observation
    z_t = pyro.sample('z_0', dist.MultivariateNormal(torch.zeros(2), torch.diag(Qnoise)))
    pyro.sample('a_0', dist.Normal(z_t @ C, Rnoise), obs=respAcc[0])

    for t in range(1, T):
        ## Transition ##
        # fx is the user's state-transition function (not shown in the post)
        z_t = pyro.sample(f"z_{t}", dist.MultivariateNormal(
            fx(z_t, xi, wn, delta, inpAcc[t - 1]), torch.diag(Qnoise)))
        ## Observation ##
        pyro.sample(f"a_{t}", dist.Normal(z_t @ C, Rnoise), obs=respAcc[t])
```
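The generative model this loop encodes can be written as follows. This assumes `fx` is the forward-Euler step of the oscillator, as in the NumPyro attempt below, with $z_t = (x_t, v_t)$, $\Delta$ = `delta`, $u_t$ = `inpAcc[t]`, $Q$ = `Qnoise` (variances on the diagonal of the covariance) and $R$ = `Rnoise` (a standard deviation, since it is the scale of `dist.Normal`):

$$
\begin{aligned}
z_0 &\sim \mathcal{N}\big(0,\ \operatorname{diag}(Q)\big),\\
z_t &\sim \mathcal{N}\!\left(\begin{bmatrix} x_{t-1} + \Delta\, v_{t-1} \\ v_{t-1} + \Delta\big(-u_{t-1} - 2\xi\omega_n v_{t-1} - \omega_n^2 x_{t-1}\big) \end{bmatrix},\ \operatorname{diag}(Q)\right), && t = 1,\dots,T-1,\\
a_t &\sim \mathcal{N}\big(-\omega_n^2 x_t - 2\xi\omega_n v_t,\ R^2\big), && t = 0,\dots,T-1.
\end{aligned}
$$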

Now I want to port it to NumPyro to take advantage of JAX's computational features. I read the factorial HMM example and other discussions on this forum about using `scan` for this purpose. Here is what I have written so far, but I get various shape and type errors. I am not sure what the right way is to define the transition function and use it. Also, I have a deterministic input `inpAcc[t]` at each time step, and I do not know how to include it in my transition function. I would appreciate any help with this migration.

```python
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

def model(respAcc):
    T = respAcc.shape[0]
    N = 2  # dimension of the latent variables

    # Parameters (revert back to constrained space)
    xi = numpyro.sample('xi', dist.Uniform(0.0, 1.0))  # two-sided constraint
    wn = numpyro.sample('wn', dist.LogNormal(-1.0, 0.4))

    # prior distributions
    z0 = numpyro.sample('z0', dist.MultivariateNormal(jnp.zeros(N), jnp.eye(N)))
    # a0 = numpyro.sample('a0', dist.Normal(torch.matmul(z_t.squeeze(0),torch.tensor([[-np.square(wn)],[-(2*xi*wn)]])), Rnoise), obs=respAcc[0])

    # Propagate the dynamics forward using jax.lax.scan
    def transition(z_prev, a_curr, xi, wn):
        x1 = z_prev[0] + dt*x[1]
        x2 = z_prev[1] + dt*(-exc - (2*xi*wn)*z_prev[1] - np.square(wn)*z_prev[0])
        z_curr = numpyro.sample(f"z_{t}", dist.MultivariateNormal(np.stack((x1, x2), axis=0), 0.0001*jnp.eye(N)))
        a_curr = numpyro.sample(f"a_{t}", dist.Normal(np.matmul(z_t.squeeze(0),np.array([[-np.square(wn)],[-(2*xi*wn)]])), 0.001), obs=a_curr)

        return z_curr, (z_curr, a_curr)

    _, (z, a) = jax.lax.scan(transition, z0, inpAcc[t], xi, wn)
    numpyro.deterministic('z', z)  # should I use something like this?
    numpyro.sample('obs', dist.Normal(z.squeeze(-1), 0.1), obs=respAcc[t])  # should I use something like this?

z_jax = jnp.array(respAcc)
```
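For reference, `jax.lax.scan(f, init, xs)` calls `f(carry, x)` once per slice of `xs` along the leading axis and stacks the second value that `f` returns. The positional parameters after `xs` are options such as `length` and `reverse`, so `xi` and `wn` cannot be passed as extra arguments, and the body never receives a loop index `t`. A known input sequence such as `inpAcc` can instead be passed as `xs`, with the parameters captured by closure. Below is a minimal, noise-free sketch in plain JAX, assuming the same forward-Euler step as in the attempt above; `simulate` is a hypothetical helper name, and the observation equation from the Pyro model is applied to all stacked states at once after the scan:

```python
import jax
import jax.numpy as jnp

def simulate(xi, wn, z0, inpAcc, dt):
    """Noise-free forward-Euler simulation of a linear SDOF oscillator.

    z = (x, v) is the state and inpAcc is the known input acceleration.
    Returns the stacked states, shape (len(inpAcc), 2), and the predicted
    accelerations a_t = -wn**2 * x_t - 2*xi*wn * v_t, shape (len(inpAcc),).
    """
    def transition(z_prev, u_prev):
        # xi, wn and dt come from the enclosing scope (closure);
        # u_prev is the current element of the scanned input sequence.
        x, v = z_prev[0], z_prev[1]
        z_next = jnp.stack([x + dt * v,
                            v + dt * (-u_prev - 2 * xi * wn * v - wn**2 * x)])
        return z_next, z_next  # (new carry, per-step output to stack)

    _, zs = jax.lax.scan(transition, z0, inpAcc)
    # Observation equation evaluated for every time step in one matmul
    acc = zs @ jnp.array([-wn**2, -2 * xi * wn])
    return zs, acc

zs, acc = simulate(0.05, 2.0, jnp.zeros(2), jnp.ones(100), 0.01)
```

The same closure-plus-`xs` pattern carries over unchanged to NumPyro's own `scan`, where the body can additionally contain `sample` statements.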

Hi @Motahareh, we have ported `scan` to work with `sample` primitives; it is available as `numpyro.contrib.control_flow.scan`. You can also find other examples of its use in the time series forecasting and discrete HMM examples.