# Kalman Filter in Numpyro

Hi folks,

I’m trying to convert a Stan prototype (Stan code here) to NumPyro. The prototype was tested and runs well in Stan, but I want to try it with JAX and NumPyro to see if I can further improve performance in a scalable setting. So far I see two main challenges:

1. the for loop within stan
2. user specified likelihood

It seems to me that 1. is doable with the `scan` function. However, when it is combined with 2., I’m a bit lost. I followed some previous posts, and it looks like I should create my own sampler. But I also wonder if I should create my own distribution instead. If that’s the case, is there a tutorial covering the minimal work needed to inherit from the `Distribution` class to make this work?

Thanks for the attention!

Edwin

You can use `factor` for a custom likelihood. It should work with `scan`.

Thanks for the suggestion. It looks like I can follow

``````
def model():
    # ...
    log_density = my_expensive_computation()
    numpyro.factor("foo", log_density)
    # ...
``````

where `my_expensive_computation` can take some latent variables from other sampling statements and output an additional `log_density`? Am I right?

So I come up with this:

``````
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def local_level(y, a1, p1):
    obs_sigma = numpyro.sample(
        "obs_sigma", dist.TruncatedNormal(low=1e-10, loc=1.0, scale=1.0)
    )
    state_sigma = numpyro.sample(
        "state_sigma", dist.TruncatedNormal(low=1e-10, loc=1.0, scale=1.0)
    )

    a = a1  # filtered state mean
    p = p1  # filtered state variance

    log_prob = jnp.zeros(1)
    for t in range(n_steps):
        vt = y[t] - a                   # one-step prediction error
        Ft = p + jnp.square(obs_sigma)  # prediction error variance

        a = a + p * vt / Ft
        p = p * (1 - p / Ft) + jnp.square(state_sigma)
        log_prob += -0.5 * (jnp.log(Ft) + jnp.square(vt) / Ft)

    numpyro.factor("kalman_filter_ll", log_prob)


kernel = NUTS(local_level)
mcmc = MCMC(kernel, num_warmup=3, num_samples=1, num_chains=num_chains,
            chain_method='parallel')

a1 = jnp.zeros(1)
p1 = jnp.ones(1)

mcmc.run(random.PRNGKey(0), y, a1=a1, p1=p1)
``````

The run gets stuck somewhere after the first (validation?) iteration. Am I following this correctly?

Could you use `scan` instead of the for loop? Could you also print out the shapes of some variables in the loop to see if they match your expectations?


Thanks for the prompt response @fehiepsi. Replacing the for loop with `scan` solved the problem. My thought was that a for loop would be slow but still an affordable way to start a prototype; apparently I was wrong. With `scan`, the model runs super fast! I guess the lesson is to always use `scan`. And your suggestion to test my transition function with `scan` first is a great idea. Thank you!