Kalman Filter in Numpyro

Hi folks,

I’m trying to convert a Stan prototype (stan code here) to NumPyro. The prototype runs pretty well in Stan, but I want to try it with JAX and NumPyro to see whether I can further improve performance for a scalable solution. So far I see two main challenges:

  1. the for loop in Stan
  2. the user-specified likelihood

It seems to me that 1. is doable with the scan function. However, once it is combined with 2., I’m a bit lost. From some previous posts, it looks like I should create my own sampler, but I also wonder whether I should create my own distribution instead. If so, is there a tutorial covering the minimal work needed to subclass the Distribution class?

Thanks for the attention!

Edwin
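For challenge 1, a Stan-style loop that carries state from one iteration to the next maps onto jax.lax.scan: the loop body becomes a step(carry, x_t) function that returns the updated carry plus a per-step output, and scan stacks those outputs. A minimal sketch, using a hypothetical AR(1) recursion x_t = phi * x_{t-1} + e_t in place of the actual model (ar1_loop and ar1_scan are illustrative names):

```python
import jax
import jax.numpy as jnp


def ar1_loop(phi, x0, eps):
    # Python for loop, written the way one would in Stan
    xs = []
    x = x0
    for t in range(eps.shape[0]):
        x = phi * x + eps[t]
        xs.append(x)
    return jnp.stack(xs)


def ar1_scan(phi, x0, eps):
    # Same recursion with lax.scan: `carry` is the state passed between steps,
    # `eps_t` is the t-th element of `eps`, and the second return value is stacked
    def step(carry, eps_t):
        x = phi * carry + eps_t
        return x, x

    _, xs = jax.lax.scan(step, x0, eps)
    return xs


eps = jax.random.normal(jax.random.PRNGKey(0), (50,))
x0 = jnp.zeros(())  # initial state; same dtype as the carry produced by `step`
xs_loop = ar1_loop(0.8, x0, eps)
xs_scan = ar1_scan(0.8, x0, eps)
```

Both functions return the same trajectory; the difference is that scan traces the loop body only once.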

You can use numpyro.factor for a custom likelihood. It should work with scan.

Thanks for the suggestion. It looks like I can follow this pattern:

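import numpyro

def model():
    # ...
    # skip the expensive likelihood when an enclosing mask handler masks it out
    if numpyro.get_mask() is not False:
        log_density = my_expensive_computation()  # placeholder from the docs example
        numpyro.factor("foo", log_density)
    # ...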

where my_expensive_computation takes latent variables from other sample sites and returns an additional log density? Am I right?

So I came up with this:

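import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

def local_level(y, a1, p1):
    obs_sigma = numpyro.sample(
        "obs_sigma", dist.TruncatedNormal(low=1e-10, loc=1.0, scale=1.0)
    )
    state_sigma = numpyro.sample(
        "state_sigma", dist.TruncatedNormal(low=1e-10, loc=1.0, scale=1.0)
    )

    a = a1  # predicted state mean
    p = p1  # predicted state variance
    n_steps = y.shape[0]

    log_prob = jnp.zeros(1)
    for t in range(n_steps):
        vt = y[t] - a                     # one-step prediction error
        Ft = p + jnp.square(obs_sigma)    # prediction-error variance

        a = a + p * vt / Ft               # gain p / Ft times the prediction error
        p = p * (1 - p / Ft) + jnp.square(state_sigma)
        log_prob += -0.5 * (jnp.log(jnp.fabs(Ft)) + jnp.square(vt) / Ft)

    numpyro.factor("kalman_filter_ll", log_prob)

# y (1-D array of observations) and num_chains are defined elsewhere
kernel = NUTS(local_level)
mcmc = MCMC(kernel, num_warmup=3, num_samples=1, num_chains=num_chains, chain_method='parallel')

a1 = jnp.zeros(1)
p1 = jnp.ones(1)

mcmc.run(random.PRNGKey(0), y, a1=a1, p1=p1)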

The run got stuck somewhere after the first (validation?) pass. Am I using this correctly?

Could you use scan instead of the for loop? Could you also print out the shapes of some variables in the loop to see whether they match your expectations?


Thanks for the prompt response, @fehiepsi. Replacing the for loop with scan solves the problem. I had thought a for loop would be slow but still acceptable for getting a prototype started; I was wrong. With scan, the model runs super fast! I guess the lesson is to always use scan? And your suggestion of testing my transition function with scan first is a great idea. Thank you!
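Why the for loop hurt: JAX traces Python control flow, so a for loop over n_steps is unrolled into n_steps copies of the loop body, and that whole unrolled program is what gets compiled when NUTS evaluates the model. A plausible explanation for the apparent hang is therefore compile time rather than sampling. scan traces the body once, whatever the sequence length. A small sketch that makes the difference visible by counting equations in the traced program with jax.make_jaxpr; the running sum (total_loop, total_scan are illustrative names) is only a stand-in for the filter:

```python
import jax
import jax.numpy as jnp


def total_loop(x):
    total = jnp.zeros(())
    for t in range(x.shape[0]):  # unrolled at trace time: one body copy per step
        total = total + x[t]
    return total


def total_scan(x):
    def step(total, x_t):
        return total + x_t, None  # body is traced only once

    total, _ = jax.lax.scan(step, jnp.zeros(()), x)
    return total


x = jnp.arange(200.0)
n_eqns_loop = len(jax.make_jaxpr(total_loop)(x).jaxpr.eqns)
n_eqns_scan = len(jax.make_jaxpr(total_scan)(x).jaxpr.eqns)
print(n_eqns_loop, n_eqns_scan)  # the loop count grows with len(x); the scan count does not
```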