Clarification on enumeration for discrete latent variable models


#1

Hi,

Quoting the enumeration tutorial: “Note that we’ve used ‘parallel’ enumeration to enumerate along a new tensor dimension. This is cheap and allows Pyro to parallelize computation, but requires downstream program structure to avoid branching on the value of z. To support dynamic program structure, you can instead use ‘sequential’ enumeration.”

I believe this means that if I am making some downstream sampling decisions based on an upstream discrete random variable, I have to use “sequential” enumeration. Is that correct?

Say that in my guide I have a Gumbel-Softmax approximation of the discrete latent variable. The model still has the OneHotCategorical latent variable, which needs to be summed out in order to compute the expected complete-data log-likelihood (the first term of the ELBO). Should I use “parallel” or “sequential” enumeration? I believe I can use “parallel”. Is that correct?

I’m a bit confused about how enumeration works.

Thank you for your comments.


#2

if I am making some downstream sampling decisions based on an upstream discrete random variable

Correct, in this case you’d need to use sequential enumeration. In some cases you can rewrite your model using poutine.mask to vectorize over a batch of decisions.

Gumbel-Softmax

Hmm, the goal of automatic enumeration is to avoid high-variance tricks like Gumbel-Softmax and instead use exact discrete enumeration, which is actually pretty cheap on parallel hardware. I don’t think you’d want to combine the two.


#3

Thank you.


#4

If I understand correctly, I can do something like this:

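import pyro
import pyro.poutine as poutine
from pyro.distributions import Bernoulli, Normal
b = pyro.sample("b", Bernoulli(probs))  # shape: (bsz); probs = switch probability
with poutine.mask(mask=b.bool()):  # mask must be passed by keyword as a bool tensor
    z = pyro.sample("z", Normal(z_loc, z_scale))  # shape: (bsz, dim)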

and enumerate the Bernoulli random variable in parallel. Is that correct?


#5

That’s right @samk, you can enumerate in parallel. I believe you’ll also need to

  1. use pyro.plate("event", dim) or .to_event(1) to ensure shapes are correct and
  2. provide a second sample statement for the poutine.mask(mask=~b.bool()) branch