Clarification on enumeration for discrete latent variable models

Hi,

Quoting the enumeration tutorial “Note that we’ve used “parallel” enumeration to enumerate along a new tensor dimension. This is cheap and allows Pyro to parallelize computation, but requires downstream program structure to avoid branching on the value of z. To support dynamic program structure, you can instead use “sequential” enumeration”.

I believe this means that if I am making some downstream sampling decisions based on an upstream discrete random variable, I have to use “sequential” enumeration. Is that correct?

Say that in my guide I have a Gumbel-Softmax approximation of the discrete latent variable. The model still has the OneHotCategorical latent variable, which needs to be summed out to compute the expected complete model likelihood (the first term) in the ELBO. Should I use “parallel” or “sequential” enumeration? I believe I can use “parallel”. Is that correct?

I’m a bit confused about how enumeration works.

Thank you for your comments.

if I am making some downstream sampling decisions based on an upstream discrete random variable

Correct, in this case you’d need to use sequential enumeration. In some cases you can rewrite your model using poutine.mask to vectorize over a batch of decisions.

Gumbel-Softmax

Hmm, the goal of automatic enumeration is to avoid high-variance tricks like Gumbel-Softmax and to use exact discrete enumeration instead, which is actually pretty cheap on parallel hardware. I don’t think you’d want to combine the two.

Thank you.

If I understand correctly, I can do something like this:

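import pyro
from pyro import poutine
from pyro.distributions import Bernoulli, Normal
b = pyro.sample("b", Bernoulli(b_probs))  # shape: (bsz,); b_probs is a placeholder name
with poutine.mask(mask=b.bool()):  # mask must be passed by keyword
    z = pyro.sample("z", Normal(z_loc, z_scale))  # shape: (bsz, dim)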

and enumerate the Bernoulli random variable in parallel. Is that correct?

That’s right @samk, you can enumerate in parallel. I believe you’ll also need to

  1. use pyro.plate("event", dim) or .to_event(1) to ensure shapes are correct and
  2. provide a second sample statement for the poutine.mask(mask=~b) branch

May I ask what point 2 means here? How do we provide a second sample statement, and for which variable?

I am having a similar problem with a Bernoulli variable: its sample only has shape [2, 1], so I cannot treat the output as a vector and therefore cannot feed it to encoder_z.
Here is a small example:

def guide(self, xs, ys=None):
    with pyro.plate('nano', 20):
        if ys is None:
            alpha = self.encoder_y.forward(xs)
            alpha = alpha.squeeze(1)
            print("alpha shape from the encoder_y", alpha.shape)
            assert alpha.shape == (20,)
            ys = pyro.sample("y", dist.Bernoulli(alpha))
            print("Bernoulli y shape is", ys.shape)
            print("Bernoulli y values are", ys)
        loc, scale = self.encoder_z.forward(xs, ys)
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))

And it gives:

alpha shape from the encoder_y torch.Size([20])
Bernoulli y shape is torch.Size([2, 1])
Bernoulli y values are tensor([[0.],
        [1.]])

However, I expected the Bernoulli sample to have shape [20, 1] (20 is my batch size). I see that [2, 1] corresponds to the two possible values, 0 and 1, but I cannot pass this to encoder_z(xs, ys) because the shape does not match.

If I remove the size from my plate and add .to_event(1) to the Bernoulli, it samples 20 times:

def guide(self, xs, ys=None):
    with pyro.plate('nano'):
        if ys is None:
            alpha = self.encoder_y.forward(xs)
            alpha = alpha.squeeze(1)
            print("alpha shape from the encoder_y", alpha.shape)
            assert alpha.shape == (20,)
            ys = pyro.sample("y", dist.Bernoulli(alpha).to_event(1))
            print("Bernoulli to_event y shape is", ys.shape)
            print("Bernoulli to_event y values are", ys)
        loc, scale = self.encoder_z.forward(xs, ys)
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))

And then it gives Bernoulli values in a vector:

alpha shape from the encoder_y torch.Size([20])
Bernoulli to_event y shape is torch.Size([20])
Bernoulli to_event y values are tensor([1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])

I do not understand why to_event(1) is needed here, especially since calling to_event() with no argument gives the same output.
If I add to_event(1) to the y sample in the guide, I also need to add it to the y sample in the model; otherwise the event shapes differ.

Do I then also need poutine.mask?

I also use parallel enumeration:

import pyro
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

ssvae.guide = config_enumerate(ssvae.guide, "parallel", expand=False)
pyro.clear_param_store()
svi = SVI(ssvae.model, ssvae.guide, optimizer, loss=TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=True))

I am quite confused about enumerating discrete variables and about when to_event() is necessary, so I would appreciate some insight!
