I have a more complex Markov model based on the HMM example in the Pyro docs, and I’m trying to use the vectorize_particles=True option to avoid a runtime slowdown of roughly O(num_particles) when running SVI with TraceEnum_ELBO or JitTraceEnum_ELBO.
However, I get indexing errors as soon as I add this flag, and I can reproduce a similar error with the unmodified HMM tutorial code by passing these kwargs to the ELBO:
Elbo(
num_particles=2,
vectorize_particles=True,
...
)
Running python hmm.py with the default arg values then produces the error trace:
/path/to/hmm.py in model_1(sequences, lengths, args, batch_size, include_prior)
213 pyro.sample(
214 "y_{}".format(t),
--> 215 dist.Bernoulli(probs_y[x.squeeze(-1)]),
216 obs=sequences[batch, t],
217 )
IndexError: index 7 is out of bounds for dimension 0 with size 2
If I print out the shapes with and without vectorizing particles, I get:
>>> probs_y.shape
torch.Size([2, 1, 1, 16, 51]) # vectorize_particles=True
torch.Size([16, 51]) # vectorize_particles=False
>>> x.squeeze(-1).shape
torch.Size([16, 1, 1]) # vectorize_particles=True
torch.Size([16, 1, 1]) # vectorize_particles=False
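To make the failure concrete, here is a small torch-free sketch of the rule I believe is in play: indexing a tensor with a single integer tensor selects along dimension 0 only. The helper `index_leading_dim` is a made-up name that just models the resulting shape and the bounds check; it is not a Pyro or PyTorch function, and it only handles the non-negative indices that enumeration produces.

```python
def index_leading_dim(tensor_shape, index_shape, index_values):
    """Model the result shape of `tensor[index]` for one integer index tensor.

    Assumed rule: the index selects along dimension 0 only, so every
    (non-negative) index value must be < tensor_shape[0], and the result
    shape is index_shape + tensor_shape[1:].
    """
    size0 = tensor_shape[0]
    for value in index_values:
        if not 0 <= value < size0:
            raise IndexError(
                f"index {value} is out of bounds for dimension 0 with size {size0}"
            )
    return tuple(index_shape) + tuple(tensor_shape[1:])


hidden_dim, data_dim, num_particles = 16, 51, 2
x_shape = (hidden_dim, 1, 1)   # x.squeeze(-1).shape from the printout
x_values = range(hidden_dim)   # enumerated hidden states 0..15

# Without vectorized particles, dim 0 of probs_y is hidden_dim, so all states fit.
plain = index_leading_dim((hidden_dim, data_dim), x_shape, x_values)
print(plain)

# With vectorized particles, dim 0 has size num_particles, so any state >= 2 fails.
try:
    index_leading_dim((num_particles, 1, 1, hidden_dim, data_dim), x_shape, x_values)
except IndexError as err:
    vectorized_error = str(err)
    print(vectorized_error)
```

The sketch reports the first offending state (2), whereas the real trace reported index 7, but the cause looks the same: with vectorize_particles=True the leading dimension of probs_y has size 2 rather than hidden_dim, so probs_y[x.squeeze(-1)] no longer indexes the hidden-state dimension.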
I’ve read through the Pyro docs on tensor shapes, but I’m still learning the ropes of Pyro, so I might be missing something simple… When I use the print_shapes argument for the HMM script without vectorized particles, it looks like the batch dims are: 16, 16, 8, 51, which I believe correspond to: <hidden dim at time t>, <hidden dim at t-1>, batch_size, data_dim
I think I get what’s happening in the default hmm script: probs_y holds the emission probabilities, so it has a fixed shape of (hidden_dim, data_dim), whereas x is a singleton being enumerated over 16 possible values, which gives it the shape (16, 1, 1). However, I don’t understand why probs_y gets three extra leading dimensions (2, 1, 1, ...) rather than either (a) a single extra dimension for the number of particles, (2, ...), or (b) four extra dimensions, corresponding to (num_particles, <hidden dim at t>, <hidden dim at t-1>, batch_size, ...).
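Spelled out as shape arithmetic, the two alternatives above compare to the printed shape as follows. The variable names are made up for illustration; only the sizes come from the printouts and the default script arguments.

```python
num_particles, hidden_dim, batch_size, data_dim = 2, 16, 8, 51

# Fixed shape of probs_y when particles are not vectorized.
base_shape = (hidden_dim, data_dim)

# probs_y.shape as printed with vectorize_particles=True.
observed = (2, 1, 1, 16, 51)

# (a) a single extra leading dimension for the particles.
guess_a = (num_particles,) + base_shape

# (b) particles plus the three batch dims seen with print_shapes.
guess_b = (num_particles, hidden_dim, hidden_dim, batch_size) + base_shape

for name, shape in [("observed", observed), ("guess_a", guess_a), ("guess_b", guess_b)]:
    extra = len(shape) - len(base_shape)
    print(f"{name}: shape={shape}, extra leading dims={extra}")
```

Neither guess matches the observed three extra leading dimensions, and that mismatch is the part I can’t explain.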
Can someone please help me understand what’s going on here, and how/if the HMM example can be modified to handle vectorize_particles=True?