State Space Modeling Using jax.lax.scan


I am trying to model a linear single-degree-of-freedom system in NumPyro and run MCMC over the global parameters xi and wn and the local latent variables x and v (the components of the state vector z, which I do not want to marginalize out). I previously used the following Pyro code, which was inefficient because of the Python loop over the Markov chain's time steps. Here is the code:

import torch
import pyro
import pyro.distributions as dist

def model(Qnoise, Rnoise, delta, respAcc, inpAcc):
    T = respAcc.shape[0]
    # Parameter distributions
    xi = pyro.sample('xi', dist.Uniform(0.0, 1.0))  # damping ratio, two-sided constraint
    wn = pyro.sample('wn', dist.LogNormal(-1.0, 0.4))  # natural frequency
    # Observation vector: a_t = -wn^2 * x_t - 2*xi*wn * v_t
    # (torch.stack keeps the autograd graph; torch.tensor([...]) would detach xi and wn)
    c = torch.stack([-torch.square(wn), -2 * xi * wn])
    # Prior distributions
    z_t = pyro.sample('z_0', dist.MultivariateNormal(torch.zeros(2), torch.diag(Qnoise)))
    pyro.sample('a_0', dist.Normal(z_t @ c, Rnoise), obs=respAcc[0])
    for t in range(1, T):
        ## Transition ## fx: user-defined one-step update driven by inpAcc[t-1]
        z_t = pyro.sample(f"z_{t}", dist.MultivariateNormal(fx(z_t, xi, wn, delta, inpAcc[t-1]),
                                                           torch.diag(Qnoise)))
        ## Observation ##
        pyro.sample(f"a_{t}", dist.Normal(z_t @ c, Rnoise), obs=respAcc[t])
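Written out, the Pyro code defines the state-space model below. This assumes that fx implements the same forward-Euler update as the transition function in the NumPyro attempt (reading its x[1] as z_prev[1] and exc as the input), with Δ = delta, u_t = inpAcc[t], a_t = respAcc[t], Q = Qnoise (variances) and R = Rnoise (a standard deviation). The transition applies for t = 1, ..., T-1 and an observation a_t exists for every t = 0, ..., T-1.

```latex
\begin{aligned}
z_t &= (x_t,\ v_t)^\top, \qquad z_0 \sim \mathcal{N}\big(0,\ \operatorname{diag}(Q)\big),\\
x_t &= x_{t-1} + \Delta\, v_{t-1} + w_{t,1},\\
v_t &= v_{t-1} + \Delta\big(-u_{t-1} - 2\xi\omega_n v_{t-1} - \omega_n^2 x_{t-1}\big) + w_{t,2},
  \qquad w_t \sim \mathcal{N}\big(0,\ \operatorname{diag}(Q)\big),\\
a_t &= -\omega_n^2 x_t - 2\xi\omega_n v_t + \varepsilon_t,
  \qquad \varepsilon_t \sim \mathcal{N}\big(0,\ R^2\big).
\end{aligned}
```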

Now I want to port it to NumPyro to take advantage of JAX. I read the documentation of the factorial HMM example and other discussions on the forum about using "scan" for this purpose. Here is what I have written so far, but I get various shape and type errors. I am not sure what the right way is to define the transition function and use it. I also have a deterministic input, "inpAcc[t]", at each time step, and I do not know how to include it in my transition function. I would appreciate any help with this migration.

def model(respAcc):
    T = respAcc.shape[0]
    N = 2 # dimension of the latent variables
    # Parameters (revert back to constrained space)
    xi = numpyro.sample('xi', dist.Uniform(0.0, 1.0))  # two-side constraint
    wn = numpyro.sample('wn', dist.LogNormal(-1.0, 0.4))
    # prior distributions
    z0 = numpyro.sample('z0', dist.MultivariateNormal(jnp.zeros(N), jnp.eye(N)))
#     a0 = numpyro.sample('a0', dist.Normal(torch.matmul(z_t.squeeze(0),torch.tensor([[-np.square(wn)],[-(2*xi*wn)]])), Rnoise), obs=respAcc[0])
    # Propagate the dynamics forward using jax.lax.scan
    def transition(z_prev, a_curr, xi, wn):
        x1 = z_prev[0] + dt*x[1]
        x2 = z_prev[1] + dt*(-exc - (2*xi*wn)*z_prev[1] - np.square(wn)*z_prev[0])        
        z_curr = numpyro.sample(f"z_{t}", dist.MultivariateNormal(np.stack((x1, x2), axis=0), 0.0001*jnp.eye(N)))
        a_curr = numpyro.sample(f"a_{t}", dist.Normal(np.matmul(z_t.squeeze(0),np.array([[-np.square(wn)],[-(2*xi*wn)]])), 0.001), obs=a_curr)

        return z_curr, (z_curr, a_curr)
    _, (z, a) = jax.lax.scan(transition, z0, inpAcc[t], xi, wn)
    numpyro.deterministic('z', z) #should I use something like this?
    numpyro.sample('obs', dist.Normal(z.squeeze(-1), 0.1), obs=respAcc[t]) #should I use something like this?

z_jax = jnp.array(respAcc)
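For reference, jax.lax.scan(f, init, xs) calls f(carry, x) once per slice of xs along its leading axis, and f must return (new_carry, y); the y values are stacked across steps. That explains part of the failure in the call above: inpAcc[t] is a single element rather than the whole input sequence, and xi and wn land in the positional slots of scan's length and reverse arguments. Parameters can be captured by closure instead, with the per-step input passed as xs. The body is also traced only once, so per-step site names like f"z_{t}" cannot be built inside it. A minimal noise-free sketch of the same Euler update in pure JAX, where simulate_sdof and the numeric values are hypothetical:

```python
import jax
import jax.numpy as jnp


def simulate_sdof(xi, wn, dt, inp_acc, z0):
    """Noise-free forward-Euler rollout of z = (x, v) driven by inp_acc.

    xi, wn and dt are captured by closure; inp_acc is scanned over,
    so step k receives inp_acc[k] as u_prev.
    """

    def step(z_prev, u_prev):
        x_prev, v_prev = z_prev[0], z_prev[1]
        x = x_prev + dt * v_prev
        v = v_prev + dt * (-u_prev - 2.0 * xi * wn * v_prev - wn**2 * x_prev)
        z = jnp.stack([x, v])
        acc = -(wn**2) * x - 2.0 * xi * wn * v  # predicted response acceleration
        return z, (z, acc)  # (new carry, per-step output to be stacked)

    z_last, (zs, accs) = jax.lax.scan(step, z0, inp_acc)
    return z_last, zs, accs


# Hypothetical input and parameter values, for illustration only.
inp = jnp.sin(jnp.linspace(0.0, 10.0, 100))
z_last, zs, accs = simulate_sdof(0.05, 2.0, 0.01, inp, jnp.zeros(2))
print(zs.shape, accs.shape)  # (100, 2) (100,)
```

Here zs[k] is the state after applying inp_acc[k], so matching the Pyro indexing (z_t uses inpAcc[t-1]) means scanning over inpAcc[:-1] starting from z_0.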

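Hi @Motahareh, we ported scan to work with sample primitives: numpyro.contrib.control_flow.scan accepts a transition function that calls numpyro.sample. You can also find other examples in the time series forecasting and discrete HMM examples.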