Speed up nested for-loop/scan in NumPyro model

Hi,

I have a dataset with multiple related sequences. I want to model this with a hierarchical sequence model, which means that I have to iterate over both the sequences and the items within each sequence.

Because nested scan is not possible (yet?), I use a Python for-loop for the outer loop and scan to iterate through each individual sequence. Unfortunately, this is still really slow to compile, so I want to ask:

Are there any ways to speed up the outer for-loop, perhaps via JAX primitives or some NumPyro settings?

Thanks!

Hi @julianstastny, could you post pseudocode for your model (or its graphical representation)? I hope we can convert the loop to a plate (rather than a scan), which can significantly speed up your model.

My model is relatively complicated, so here is a simplified version which captures the core issue.

import numpyro
import numpyro.distributions as dist

hyperparam = numpyro.sample("hyperparam", dist.Normal())
for i, session in enumerate(sessions):  # About 100 conditionally independent sessions
    state = numpyro.sample(f"{i}_param", dist.Normal(hyperparam))
    for t, trial in enumerate(session):
        state = foo(state)  # foo may include sample statements
        p = bar(state)
        # Site names must be unique, so the trial index is part of the name;
        # the trial itself is the observed value.
        numpyro.sample(f"{i}_{t}_obs", dist.Bernoulli(p), obs=trial)

To speed this up, I use scan for the inner loop.

hyperparam = numpyro.sample("hyperparam", dist.Normal())
for i, session in enumerate(sessions): # About 100 conditionally independent sessions
    state = numpyro.sample(f"{i}_param", dist.Normal(hyperparam))
    # Here I define transition function and use scan to loop through this.

So indeed, because the sessions in this model are conditionally independent given the hyperparameters, a plate is in principle possible. Does NumPyro allow that?

(Not sure if relevant, but the sessions are not all of the same length)

Yes, it does. You can make a plate over sessions, using a mask as in the enum HMM example; your model is very similar to that example. Basically, you will need to

  • mask the sessions
  • scan over the trial dimension (so make sure that the trial dimension is the leftmost dimension of the scan inputs)
  • inside the scan transition function, use a plate over the session dimension