# Clarification on enumeration for discrete latent variable models

Hi,

Quoting the enumeration tutorial: “Note that we’ve used ‘parallel’ enumeration to enumerate along a new tensor dimension. This is cheap and allows Pyro to parallelize computation, but requires downstream program structure to avoid branching on the value of `z`. To support dynamic program structure, you can instead use ‘sequential’ enumeration.”

I believe this means that if I am making some downstream sampling decisions based on an upstream discrete random variable, I have to use “sequential” enumeration. Is that correct?

Say that in my guide I have a Gumbel-Softmax approximation of the discrete latent variable. The model still has the `OneHotCategorical` latent variable, which needs to be summed out in order to compute the expected complete model likelihood (the first term in the ELBO). Should I use “parallel” or “sequential” enumeration? I believe I can use “parallel”. Is that correct?

I am a bit confused about the enumeration stuff.

> if I am making some downstream sampling decisions based on an upstream discrete random variable

Correct, in this case you’d need to use sequential enumeration. In some cases you can rewrite your model using `poutine.mask` to vectorize over a batch of decisions.

> Gumbel-Softmax

Hmm, the goal of automatic enumeration is to avoid high-variance tricks like Gumbel-Softmax and use exact discrete enumeration instead, which is actually pretty cheap on parallel hardware. I don’t think you’d want to combine the two.
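To make the "exact enumeration" point concrete, here is a minimal sketch in plain Python (no Pyro) of summing a Bernoulli latent out of an expectation instead of sampling it. The function `log_joint` and all the numbers in it are made up for illustration and do not come from any model in this thread.

```python
import math


def log_joint(x, b):
    """Hypothetical log p(x, b): b ~ Bernoulli(0.3), x | b ~ Normal(mean_b, 1)."""
    prior = 0.3 if b == 1 else 0.7
    mean = 2.0 if b == 1 else -2.0
    log_lik = -0.5 * (x - mean) ** 2 - 0.5 * math.log(2 * math.pi)
    return math.log(prior) + log_lik


def expected_log_joint(x, q1):
    """E_{q(b)}[log p(x, b)], computed exactly by enumerating b in {0, 1}.

    q1 is the guide's probability that b == 1.
    """
    return (1.0 - q1) * log_joint(x, 0) + q1 * log_joint(x, 1)


# No sampling noise: the expectation is a weighted sum over both values of b.
value = expected_log_joint(x=1.5, q1=0.9)
```

With only two values to visit, the exact sum costs two evaluations of the log-joint, which is the cheap case enumeration is designed for.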

Thank you.

If I understand correctly, I can do something like this:

```python
b = pyro.sample("b", Bernoulli(z))  # shape: (bsz)
z = pyro.sample("z", Normal(z_loc, z_scale))  # shape: (bsz, dim)
```

and enumerate the Bernoulli RV in parallel. Is that correct?

That’s right @samk, you can enumerate in parallel. I believe you’ll also need to

1. use `pyro.plate("event", dim)` or `.to_event(1)` to ensure shapes are correct and
2. provide a second sample statement for the `poutine.mask(mask=~b)` branch
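As a rough illustration of point 2, the sketch below uses plain PyTorch (it is not Pyro's implementation) to show what two complementary masks compute: every batch element takes its log-probability from exactly one of two branches. In Pyro this would presumably correspond to two sample sites, one under the mask `b` and one under the mask `~b`. The two `Normal` distributions and the values of `b` and `z` are hypothetical.

```python
import torch
from torch.distributions import Normal

# Hypothetical batch of four binary decisions and four continuous values.
b = torch.tensor([1.0, 0.0, 1.0, 0.0])
z = torch.tensor([0.1, -0.3, 2.0, 0.5])
mask = b.bool()

# Log-probabilities of z under each branch, evaluated for the whole batch.
lp_on = Normal(0.0, 1.0).log_prob(z)    # branch used where b == 1
lp_off = Normal(0.0, 10.0).log_prob(z)  # branch used where b == 0

# Masking zeroes out the branch that does not apply to each element,
# so no Python-level `if` on the value of b is needed.
per_element = lp_on * mask.float() + lp_off * (~mask).float()
```

Because both branches are evaluated for every element and then masked, the computation stays vectorized, which is what parallel enumeration requires.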

May I ask what 2) means here? How can we provide a second sample statement, and for which parameter?

I am having a similar problem with Bernoulli: the sample only has shape [2,1], so I cannot get its output as a vector and therefore cannot feed it to `encoder_z`.
Here is a small code example:

```python
def guide(self, xs, ys=None):
    with pyro.plate('nano', 20):
        if ys is None:
            alpha = self.encoder_y.forward(xs)
            alpha = alpha.squeeze(1)
            print("alpha shape from the encoder_y", alpha.shape)
            assert alpha.shape == (20,)
            ys = pyro.sample("y", dist.Bernoulli(alpha))
            print("Bernoulli y shape is", ys.shape)
            print("Bernoulli y values are", ys)
        loc, scale = self.encoder_z.forward(xs, ys)
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))
```

And it gives:

alpha shape from the encoder_y torch.Size([20])
Bernoulli y shape is torch.Size([2, 1])
Bernoulli y values are tensor([[0.],
[1.]])

However, the Bernoulli shape should be [20,1] (20 is my batch size). I see that [2,1] is the shortened version holding the possible values [0,1], but I cannot use this sample in my `encoder_z(xs, ys)` because the shape is not what it expects.
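For reference, here is a small sketch of how a `(2, 1)` tensor of enumerated values can still be combined with a batch of 20 through broadcasting. It is plain PyTorch, the feature size 5 is invented, and whether `encoder_z` tolerates the extra leading dimension depends on how it is written.

```python
import torch

batch_size, feat_dim = 20, 5  # feat_dim is a made-up size for illustration
xs = torch.randn(batch_size, feat_dim)

# Enumerated values as printed above: one row per possible value of y.
ys_enum = torch.tensor([[0.0], [1.0]])  # shape (2, 1)

# The leading dim indexes the enumerated value; the size-1 dim expands
# to the batch, so every data point is paired with both y = 0 and y = 1.
ys_full = ys_enum.expand(2, batch_size)       # shape (2, 20)
xs_full = xs.expand(2, batch_size, feat_dim)  # shape (2, 20, 5)

# Concatenate along the last dim, as an encoder taking (xs, ys) might do.
encoder_input = torch.cat([xs_full, ys_full.unsqueeze(-1)], dim=-1)
```

The result has shape `(2, 20, 6)`: slice 0 holds the batch with `y = 0` appended and slice 1 holds the same batch with `y = 1` appended.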

If I remove the size argument (20) from my plate and add `to_event(1)` to the Bernoulli, then it samples 20 values.

```python
def guide(self, xs, ys=None):
    with pyro.plate('nano'):
        if ys is None:
            alpha = self.encoder_y.forward(xs)
            alpha = alpha.squeeze(1)
            print("alpha shape from the encoder_y", alpha.shape)
            assert alpha.shape == (20,)
            ys = pyro.sample("y", dist.Bernoulli(alpha).to_event(1))
            print("Bernoulli to_event y shape is", ys.shape)
            print("Bernoulli to_event y values are", ys)
        loc, scale = self.encoder_z.forward(xs, ys)
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))
```

And then it gives Bernoulli values in a vector:

alpha shape from the encoder_y torch.Size([20])
Bernoulli to_event y shape is torch.Size([20])
Bernoulli to_event y values are tensor([1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1.])

I do not understand why we need `to_event(1)` here, because calling `to_event()` with no argument gives the same output as `to_event(1)`.
If I put `to_event(1)` on the `y` sample statement in the guide, then I need to do the same (add `to_event(1)`) on the `y` sample statement in the model, otherwise the event shapes would differ.
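A minimal sketch of what `to_event(1)` changes, using `torch.distributions.Independent`, which is what Pyro's `to_event` wraps. The value `0.5` for `alpha` and the all-ones `y` are placeholders. With the batch dimension reinterpreted as an event dimension, the 20 Bernoullis become one joint distribution over a length-20 vector, which is likely why the site gets sampled as a whole vector rather than enumerated.

```python
import torch
from torch.distributions import Bernoulli, Independent

alpha = torch.full((20,), 0.5)  # placeholder probabilities
y = torch.ones(20)              # placeholder observation

# 20 independent scalar distributions: batch_shape (20,), event_shape ().
batch_dist = Bernoulli(alpha)

# One distribution over a length-20 vector: batch_shape (), event_shape (20,).
event_dist = Independent(Bernoulli(alpha), 1)

lp_batch = batch_dist.log_prob(y)  # one log-prob per element
lp_event = event_dist.log_prob(y)  # a single scalar, summed over the event dim
```

The two log-probabilities carry the same information (`lp_event` equals `lp_batch.sum()`), but the shapes differ, which is why the model and guide have to agree on the event shape of `y`.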

But then do I need to use `poutine.mask`?

I use parallel enumeration too:

```python
ssvae.guide = config_enumerate(ssvae.guide, "parallel", expand=False)
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=True)
svi = SVI(ssvae.model, ssvae.guide, optimizer, loss=elbo)
```

I am quite confused about the enumeration of discrete variables and about when `to_event()` is necessary. I would appreciate some insight!
