Vectorize particles in discrete HMM example

I have a more complex Markov model based on the HMM example in the Pyro docs, and I’m trying to use the vectorize_particles=True option to avoid a runtime slowdown of about O(num_particles) with SVI using TraceEnum_ELBO or JitTraceEnum_ELBO.

However, I get weird indexing errors when I add this flag, and I can replicate a similar error with the HMM tutorial code by adding kwargs to the ELBO:

Elbo(
    num_particles=2,
    vectorize_particles=True,
    ...
)

Running python hmm.py with the default arg values then produces the error trace:

/path/to/hmm.py in model_1(sequences, lengths, args, batch_size, include_prior)
    213                     pyro.sample(
    214                         "y_{}".format(t),
--> 215                         dist.Bernoulli(probs_y[x.squeeze(-1)]),
    216                         obs=sequences[batch, t],
    217                     )

IndexError: index 7 is out of bounds for dimension 0 with size 2

If I print out the shapes with and without vectorizing particles, I get:

>>> probs_y.shape
torch.Size([2, 1, 1, 16, 51])  # vectorize_particles=True
torch.Size([16, 51])           # vectorize_particles=False
>>> x.squeeze(-1).shape
torch.Size([16, 1, 1])         # vectorize_particles=True
torch.Size([16, 1, 1])         # vectorize_particles=False

I’ve read through the Pyro docs on tensor shapes, but I’m still learning the ropes of Pyro, so I might be missing something simple… When I use the print_shapes argument for the HMM script without vectorized particles, it looks like the batch dims are: 16, 16, 8, 51, which I believe correspond to: <hidden dim at time t>, <hidden dim at t-1>, batch_size, data_dim

I think I understand what’s happening in the default HMM script: probs_y holds the emission probabilities, so it has a fixed shape of (hidden_dim, data_dim), whereas x is a single value enumerated over 16 possible states, giving it the shape (16, 1, 1). However, I don’t understand why probs_y gains three extra dimensions (2, 1, 1, ...) rather than either (a) one extra dimension for the number of particles (2, ...), or (b) four extra dimensions corresponding to (num_particles, <hidden dim t>, <hidden dim t-1>, batch_size, ...).

Can someone please help me understand what’s going on here, and how/if the HMM example can be modified to handle vectorize_particles=True?

cc @ordabayev

Hi @ejb

Those two additional dimensions correspond to two plates: sequences (dim=-2) and tones (dim=-1). Vectorizing particles effectively adds an additional plate at dim=-3 (see https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/elbo.py#L186).

So for example when particles are not vectorized you get probs_y.shape = (16, 51) where both dims are event dims (dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2)).

When particles are vectorized then the model and guide are wrapped by an additional plate:

with pyro.plate(
    "num_particles_vectorized",
    self.num_particles,
    dim=-self.max_plate_nesting,  # this is -3 in your case
):
    # the model and guide bodies run inside this plate, e.g.:
    probs_y = pyro.sample(
        "probs_y",
        dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
    )
Note that batch dims -1 and -2 are already allocated for sequences and tones plates so num_particles_vectorized plate takes the next available dim (-3).

The same logic applies to x: the first three batch dims from the right are already allocated (num_particles_vectorized=-3, sequences=-2, tones=-1), so it uses dim=-4 for enumeration.

Hope this helps, let me know if you have further questions.

@ordabayev Thanks for your quick reply! That does help explain what’s going on.

I’m still having trouble understanding how to modify the HMM indexing to enable vectorizing particles, though. I think this comes down to PyTorch tensor indexing that’s not specific to Pyro, but I had trouble finding good, clear resources on complex tensor indexing like this.

The line probs_y[x.squeeze(-1)] uses a tensor of sampled/enumerated hidden state values x, to index a tensor probs_y of emissions, to retrieve emission probs associated with those hidden states. However, if probs_y has shape [2, 1, 1, 16, 51], how can we index this with a tensor of shape [16, 1, 1]? I think the expected output shape should be [2, 1, 1, 16, 1, 1, 51].

For a simpler indexing example, the example HMM code has indexing analogous to: torch.zeros(16,51)[torch.zeros(16, 1, 1).int()], where the final shape is [16, 1, 1, 51]. If I use shapes matching the vectorized case, torch.zeros(2,1,1,16,51)[torch.zeros(16, 1, 1).int()] has output shape [16, 1, 1, 1, 1, 16, 51], but this is wrong because it uses hidden-state indices in the range [0, 15] to index the particle dimension, whose valid indices are only [0, 1]. That causes an indexing error when real indices are used instead of just zeros.
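The mismatch described above can be reproduced with plain PyTorch, without Pyro. This is a minimal sketch under assumed sizes taken from the thread (16 hidden states, 51 tones, 2 particles); the random tensors and variable names are placeholders, not the HMM example's actual values.

```python
import torch

hidden_dim, data_dim, num_particles = 16, 51, 2

# Enumerated hidden states: values 0..15 along dim 0, shape [16, 1, 1].
x = torch.arange(hidden_dim).reshape(hidden_dim, 1, 1)

# Non-vectorized case: dim 0 of probs_y is the hidden-state dim,
# so x selects one row of emission probabilities per enumerated state.
probs_y = torch.rand(hidden_dim, data_dim)
plain_shape = probs_y[x].shape  # index shape [16, 1, 1] + remaining dim [51]

# Vectorized case: dim 0 is now the particle dim (size 2), so the same
# expression sends hidden-state values up to 15 into a dim of size 2.
probs_y_vec = torch.rand(num_particles, 1, 1, hidden_dim, data_dim)
try:
    probs_y_vec[x]
    error_message = None
except IndexError as err:
    error_message = str(err)

print(tuple(plain_shape))
print(error_message)
```

With real index values the second lookup fails, which matches the kind of IndexError shown in the traceback at the top of the thread.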

Do you have any suggestions on how to do this indexing? torch.index_select doesn’t seem to apply since it expects indexes to be a vector, not an N-order tensor, and I’m having trouble grokking torch.gather and how/if it could do this.

Let’s first figure out what the expected output shape of indexing probs_y with x should be. Then we can reverse engineer how to index it.

The expected output is a tensor of Bernoulli probabilities of shape [16, 2, 1, 51]:
dim=-4 is the enumeration dim for x (16 values)
dim=-3 is the particle vectorization dim (2 particles)
dim=-2 is allocated for the sequences plate (1 value broadcast over the batch)
dim=-1 is allocated for the tones plate (51 values)

To achieve that, you can use the Vindex operator.

# requires: from pyro.ops.indexing import Vindex
with tones_plate as tdx:
    pyro.sample(
        "y_{}".format(t),
        dist.Bernoulli(Vindex(probs_y)[..., x, tdx]),
        obs=sequences[batch, t],
    )

tdx here captures the plate indices; it is the same as torch.arange(data_dim).

The documentation also covers Vindex: see "Miscellaneous Ops" in the Pyro documentation.
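To see why the result has shape [16, 2, 1, 51], here is a plain-PyTorch sketch of the broadcasting lookup that Vindex(probs_y)[..., x, tdx] is expected to perform. It is a hand-rolled illustration, not Pyro's implementation; the sizes come from the thread, and particle_idx and zero are hypothetical helper names introduced only for this example.

```python
import torch

num_particles, hidden_dim, data_dim = 2, 16, 51

# probs_y: batch dims (particles, sequences, tones) = (2, 1, 1), event dims (16, 51).
probs_y = torch.rand(num_particles, 1, 1, hidden_dim, data_dim)

# x: enumerated hidden state at dim=-4, shape [16, 1, 1, 1].
x = torch.arange(hidden_dim).reshape(hidden_dim, 1, 1, 1)

# tdx: indices of the tones plate, shape [51].
tdx = torch.arange(data_dim)

# Explicit indices for the three batch dims so that every index tensor
# broadcasts against the others instead of hitting the wrong dim.
particle_idx = torch.arange(num_particles).reshape(num_particles, 1, 1)
zero = torch.zeros(1, dtype=torch.long)  # the size-1 sequences and tones dims

# All five dims are indexed, so the result has the broadcast shape of the indices.
emission_probs = probs_y[particle_idx, zero, zero, x, tdx]

print(tuple(emission_probs.shape))
```

Each entry emission_probs[k, p, 0, d] is probs_y[p, 0, 0, k, d]: the probability of tone d under hidden state k for particle p.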

This is very helpful, thank you @ordabayev !

After making this change I got another error with indexing, this time at probs_x[x]. I think I figured out a solution using Vindex, but just wanted to confirm that this is correct and describe my thinking in case someone else might find this useful later.

x = pyro.sample(
    "x_{}".format(t),
    dist.Categorical(Vindex(probs_x)[..., x, :]),  # updated
    infer={"enumerate": "parallel"},
)

After making this change as well as what you suggested, the discrete HMM code runs with vectorized particles, and the loss seems to decrease following a similar pattern as non-vectorized particles without the change.

From the tensor shapes docs, we have: shape = sample_shape + batch_shape + event_shape.

probs_x has shape [2, 1, 1, 16, 16], with the first three dimensions being batch dimensions for the plates (particles, sequences, tones), and the last two being event dimensions for the 16x16 HMM transition matrix. x has shape [16, 1, 1, 1], with the first dimension being a sample shape for enumeration over hidden state values for \text{x}_{t-1}, and the last three dimensions being the same plate batch dims (particles, seq, tones).

Since we’re indexing probs_x to get a vector of transition probabilities, we should expect the final shape of probs_x[???] to be [16, 2, 1, 1, 16], where the leading 16 is a sample dim for enumerating \text{x}_{t-1}, [2, 1, 1] are the plate batch dims, and the trailing 16 is the event dimension, which we need in order to draw \text{x}_{t} \sim Categorical(...).

Another little sanity check for me is that dist.Categorical(Vindex(probs_x)[..., x, :]).sample().shape == [16, 2, 1, 1], which checks out since we’re sampling a single \text{x}_{t}, enumerated over 16 \text{x}_{t-1} values and batched over the 3 plated dims.
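The shape reasoning above can be checked with plain PyTorch. The sketch below mimics what Vindex(probs_x)[..., x, :] is expected to return by using explicit broadcasting indices; it is an illustration under assumed sizes from the thread, not Pyro's implementation. The names x_prev, particle_idx, and zero are hypothetical, and torch.distributions.Categorical stands in for pyro's dist.Categorical.

```python
import torch

num_particles, hidden_dim = 2, 16

# probs_x: batch dims (2, 1, 1) for the plates, event dims (16, 16) for the
# transition matrix; softmax makes each row a valid probability vector.
probs_x = torch.softmax(
    torch.randn(num_particles, 1, 1, hidden_dim, hidden_dim), dim=-1
)

# Previous hidden state x_{t-1}, enumerated at dim=-4: shape [16, 1, 1, 1].
x_prev = torch.arange(hidden_dim).reshape(hidden_dim, 1, 1, 1)

# Explicit indices for the three batch dims; the last dim is left unindexed,
# which plays the role of the trailing ":" in the Vindex expression.
particle_idx = torch.arange(num_particles).reshape(num_particles, 1, 1)
zero = torch.zeros(1, dtype=torch.long)
transition_probs = probs_x[particle_idx, zero, zero, x_prev]

# One draw of x_t per (enumerated x_{t-1}, particle, sequence, tone) cell.
x_next = torch.distributions.Categorical(probs=transition_probs).sample()

print(tuple(transition_probs.shape), tuple(x_next.shape))
```

Row transition_probs[k, p, 0, 0] is the transition distribution out of state k for particle p, so the Categorical keeps [16, 2, 1, 1] as its batch shape and consumes the final 16 as the category dim.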

Glad it helped @ejb . It all looks correct to me!
