I have a more complex Markov model based on the HMM example in the Pyro docs, and I’m trying to use the vectorize_particles=True option to avoid a runtime slowdown that scales roughly as O(num_particles) with SVI using TraceEnum_ELBO or JitTraceEnum_ELBO.
However, I get weird indexing errors when I add this flag, and I can replicate a similar error with the HMM tutorial code by adding kwargs to the ELBO:
    Elbo(
        num_particles=2,
        vectorize_particles=True,
        ...
    )
Running python hmm.py with the default arg values then produces the error trace:
/path/to/hmm.py in model_1(sequences, lengths, args, batch_size, include_prior)
213 pyro.sample(
214 "y_{}".format(t),
--> 215 dist.Bernoulli(probs_y[x.squeeze(-1)]),
216 obs=sequences[batch, t],
217 )
IndexError: index 7 is out of bounds for dimension 0 with size 2
If I print out the shapes with and without vectorizing particles, I get:
>>> probs_y.shape
torch.Size([2, 1, 1, 16, 51]) # vectorize_particles=True
torch.Size([16, 51]) # vectorize_particles=False
>>> x.squeeze(-1).shape
torch.Size([16, 1, 1]) # vectorize_particles=True
torch.Size([16, 1, 1]) # vectorize_particles=False
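To separate the failure from Pyro itself, here is a minimal plain-PyTorch sketch that uses only the shapes printed above. The random tensors and the names probs_y_plain, probs_y_vec, and error_message are placeholders for illustration, not values from hmm.py. It only shows that indexing with an integer tensor selects along dimension 0, which has size 16 without vectorized particles but size 2 with them.

```python
import torch

hidden_dim, data_dim, num_particles = 16, 51, 2

# Stand-in for the enumerated hidden state: one value per hidden state,
# with the (16, 1, 1) shape reported for x.squeeze(-1).
x = torch.arange(hidden_dim).reshape(hidden_dim, 1, 1)

# Without vectorized particles: dimension 0 of probs_y is the hidden axis,
# so every value in x (0..15) is a valid index.
probs_y_plain = torch.rand(hidden_dim, data_dim)
out_plain = probs_y_plain[x]  # shape (16, 1, 1, 51)

# With vectorized particles: dimension 0 is now the particle axis (size 2),
# so any x value >= 2 is out of bounds.
probs_y_vec = torch.rand(num_particles, 1, 1, hidden_dim, data_dim)
try:
    probs_y_vec[x]
    error_message = None
except IndexError as err:
    error_message = str(err)

print(out_plain.shape)
print(error_message)
```

The exact index named in the error message can differ from run to run or between devices, so the sketch does not rely on it.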
I’ve read through the Pyro docs on tensor shapes, but I’m still learning the ropes of Pyro, so I might be missing something simple… When I use the print_shapes argument for the HMM script without vectorized particles, it looks like the batch dims are 16, 16, 8, 51, which I believe correspond to <hidden dim at time t>, <hidden dim at t-1>, batch_size, data_dim.
I think I get what’s happening in the default HMM script: probs_y holds the emission probabilities, so it has a fixed shape of (hidden_dim, data_dim), whereas x is a singleton being enumerated over 16 possible values, giving it the shape (16, 1, 1). However, I don’t understand why probs_y gets three extra dimensions added (2, 1, 1, ...) instead of either (a) just one, for the number of particles (2, ...), or (b) four, corresponding to (num_particles, <hidden dim at t>, <hidden dim at t-1>, batch_size, ...).
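To make the shape question concrete, here is a plain-PyTorch sketch of a lookup that indexes the hidden-state axis explicitly while broadcasting against the particle axis. It is an illustration under assumptions, not a confirmed fix for hmm.py: it assumes the two size-1 axes in (2, 1, 1, 16, 51) can be dropped for the lookup, and the names flat, particle_idx, and out are hypothetical.

```python
import torch

hidden_dim, data_dim, num_particles = 16, 51, 2

# Shapes taken from the vectorize_particles=True printout.
probs_y_vec = torch.rand(num_particles, 1, 1, hidden_dim, data_dim)
x_sq = torch.arange(hidden_dim).reshape(hidden_dim, 1, 1)

# Assumption: the two singleton axes carry no information for the lookup,
# so collapse them to get (num_particles, hidden_dim, data_dim).
flat = probs_y_vec.reshape(num_particles, hidden_dim, data_dim)

# Index the particle axis with a (2, 1) tensor so that it broadcasts
# against x_sq (16, 1, 1) to an index shape of (16, 2, 1).
particle_idx = torch.arange(num_particles).reshape(num_particles, 1)

# Each particle looks up its own emission row for each enumerated state.
out = flat[particle_idx, x_sq]  # shape (16, 2, 1, 51)

print(out.shape)
```

In this layout the enumerated hidden state sits to the left of the particle axis, and out[k, n, 0] is the emission row for hidden state k under particle n. Pyro also ships pyro.ops.indexing.Vindex for broadcast-aware indexing, though whether it resolves this particular case is not verified here.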
Can someone please help me understand what’s going on here, and how (or whether) the HMM example can be modified to handle vectorize_particles=True?