Hi,
I have a dataset with multiple related sequences. I want to model it with a hierarchical sequence model, which means I have to iterate both over the sequences and over the items within each sequence.
Because nested scan is not possible (yet?), I use a Python for-loop for the outer loop and scan to iterate over each individual sequence. Unfortunately, this is still very slow to compile, so I want to ask:
Are there any ways to speed up the outer for-loop, perhaps via JAX primitives or some NumPyro settings?
Thanks!
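As background on why the outer loop hurts compile time: when a function is traced by JAX (as happens when NumPyro's MCMC compiles the model's potential function), a Python `for` loop is unrolled, so the traced program grows linearly with the number of iterations, while `jax.lax.scan` traces its body once. The sketch below illustrates this in plain JAX with a toy `step` function (hypothetical, not from the thread) by counting the top-level operations in each traced program.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Toy loop body; stands in for any per-iteration computation.
    return jnp.sin(x) * 2.0


def python_loop(x, n):
    for _ in range(n):  # unrolled while tracing: n copies of `step` in the graph
        x = step(x)
    return x


def scan_loop(x, n):
    def body(carry, _):
        return step(carry), None

    x, _ = jax.lax.scan(body, x, None, length=n)  # `body` is traced only once
    return x


def graph_size(f, n):
    # Number of top-level operations in the traced program for f(x, n).
    return len(jax.make_jaxpr(lambda x: f(x, n))(jnp.float32(1.0)).jaxpr.eqns)


for n in (10, 100):
    print(n, graph_size(python_loop, n), graph_size(scan_loop, n))
```

The unrolled graph grows with `n` while the scanned one stays the same size, which is why replacing a long Python loop with a vectorized or scanned construct usually shortens compilation.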
Hi @julianstastny, could you post pseudocode for your model (or its graphical representation)? I hope we can convert the loop to a plate (rather than a scan), which can significantly speed up your model.
My model is relatively complicated, so here is a simplified version which captures the core issue.
import numpyro
import numpyro.distributions as dist

hyperparam = numpyro.sample("hyperparam", dist.Normal())
for i, session in enumerate(sessions):  # About 100 conditionally independent sessions
    state = numpyro.sample(f"{i}_param", dist.Normal(hyperparam))
    for t, trial in enumerate(session):
        state = foo(state)  # foo may include sample statements
        p = bar(state)
        # site names must be unique, so include the trial index; trial = observed 0/1 outcome (assumed)
        numpyro.sample(f"{i}_{t}_obs", dist.Bernoulli(p), obs=trial)
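Ignoring any sample statements inside `foo` (an assumption made only for this sketch), the simplified model above defines the joint density below, where $N$ is the number of sessions, $T_i$ the length of session $i$, $\theta_i$ the per-session parameter, and $s_{i,t}$ the state after trial $t$. The session terms enter as a product over $i$, which is the conditional independence a plate expresses.

$$
p(h, \theta, y) = \mathcal{N}(h \mid 0, 1) \prod_{i=1}^{N} \Big[ \mathcal{N}(\theta_i \mid h, 1) \prod_{t=1}^{T_i} \mathrm{Bernoulli}\big(y_{i,t} \mid \mathrm{bar}(s_{i,t})\big) \Big],
\qquad s_{i,0} = \theta_i,\quad s_{i,t} = \mathrm{foo}(s_{i,t-1})
$$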
To speed this up, I use scan:
hyperparam = numpyro.sample("hyperparam", dist.Normal())
for i, session in enumerate(sessions):  # About 100 conditionally independent sessions
    state = numpyro.sample(f"{i}_param", dist.Normal(hyperparam))
    # Here I define a transition function and use scan to loop through the trials.
So indeed, because in this model the sessions are conditionally independent given the hyperparameters, a plate is possible in principle. Does NumPyro allow that?
(Not sure if relevant, but the sessions are not all of the same length)
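Sessions of different lengths can still share one plate if they are padded to a common length and paired with a mask. A minimal sketch in NumPy, assuming each session is a list of 0/1 outcomes; `pad_sessions` is a hypothetical helper, not a NumPyro function. It puts the trial dimension first, which is the layout a scan over trials expects.

```python
import numpy as np


def pad_sessions(sessions):
    """Stack ragged sessions into a trial-major (max_len, n_sessions) array plus a mask."""
    lengths = np.array([len(s) for s in sessions])
    max_len = int(lengths.max())
    padded = np.zeros((max_len, len(sessions)), dtype=np.float32)
    for j, s in enumerate(sessions):
        padded[: len(s), j] = s  # real trials first, zero padding after
    # mask[t, j] is True exactly where trial t exists in session j
    mask = np.arange(max_len)[:, None] < lengths[None, :]
    return padded, mask


padded, mask = pad_sessions([[1, 0, 1], [0, 1], [1, 1, 0, 0]])
```

The padded values are arbitrary; only the mask decides which entries count.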
Yes, it does. You can make a plate over sessions (using a mask, as in the enum HMM example). Your model is very similar to that HMM example. Basically, you will need to:
- mask the sessions (pad them to a common length and mask out the padded trials)
- scan over the trial dimension (so make sure the trial dimension is the leftmost dimension of the scan inputs)
- inside the scan transition function, use a plate over the session dimension
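The steps above can be sketched in plain JAX, assuming the data is already a trial-major `(max_len, n_sessions)` array with a boolean mask. The `transition` function and the logit are hypothetical stand-ins for `foo` and `bar` (and, unlike `foo`, contain no sample statements). In a NumPyro model the per-trial term would come from a `Bernoulli` site inside a session plate under a mask; here the log-likelihood is written out by hand so the session axis plays the role of the plate: each scan step advances every session at once, and the result is summed over trials and sessions.

```python
import jax
import jax.numpy as jnp


def transition(state):
    # Hypothetical stand-in for `foo`: a deterministic decay toward zero.
    return 0.9 * state


def batched_loglik(init_state, obs, mask):
    """Masked Bernoulli log-likelihood of all sessions in one scan.

    init_state: (n_sessions,) per-session starting state (the `{i}_param` draws).
    obs, mask:  (max_len, n_sessions); trial axis leftmost, so scan walks trials.
    """

    def step(state, inputs):
        y_t, m_t = inputs  # each of shape (n_sessions,): trial t of every session
        state = transition(state)  # advance all sessions in parallel
        logit = state  # hypothetical stand-in for `bar`: p = sigmoid(state)
        ll = y_t * jax.nn.log_sigmoid(logit) + (1 - y_t) * jax.nn.log_sigmoid(-logit)
        ll = jnp.where(m_t, ll, 0.0)  # padded trials contribute nothing
        return state, ll

    _, ll = jax.lax.scan(step, init_state, (obs, mask))
    return ll.sum()  # sum over trials and over sessions (the "plate")


# Sessions [1, 0, 1], [0, 1] and [1, 1, 0, 0], padded to shape (4, 3), trial-major.
obs = jnp.array([[1, 0, 1],
                 [0, 1, 1],
                 [1, 0, 0],
                 [0, 0, 0]], dtype=jnp.float32)
mask = jnp.arange(4)[:, None] < jnp.array([3, 2, 4])[None, :]
init_state = jnp.array([0.5, -1.0, 2.0])
total = batched_loglik(init_state, obs, mask)
```

Because padding only ever follows a session's real trials, the extra transition steps on padded positions never feed back into real trials; the mask just zeroes their contribution.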