Question about the enumeration tutorial

  • What tutorial are you running?
    Inference with Discrete Latent Variables
  • What version of Pyro are you using?
  • Please link or paste relevant code, and steps to reproduce.
    In the Vindex section:

```python
with data_plate:
    c = pyro.sample("c", dist.Categorical(torch.ones(4)))
    with feature_plate as vdx:                # Capture plate index.
        pc = Vindex(p)[vdx[..., None], c, :]  # Reshape it and use in Vindex.
        x = pyro.sample("x", dist.Categorical(pc),
                        obs=torch.zeros(5, 6, dtype=torch.long))
```

I know that vdx[..., None] adds a new trailing dimension to vdx, but why do we need to do this, and under what conditions should we do it?


The short answer is that feature_plate declares conditional independence along batch dimension dim=-2. Within that plate context, dim=-2 of the batch shape of any distribution is “reserved” for feature_plate, just as dim=-1 is reserved for data_plate. You therefore want dist.Categorical(pc) to have batch and event shapes that line up with those plate dimensions. vdx is a 1-D tensor of feature indices, so vdx[..., None] moves it to dim=-2, where it broadcasts against c at dim=-1. Vindex is a helper for this kind of advanced indexing. Hope this helps.
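A minimal sketch of the shapes involved, written in plain PyTorch rather than Pyro. It assumes p has shape (5, 4, 3): 5 features, 4 components for c, and 3 outcome categories for x. The plate sizes 5 and 6 follow from the obs shape (5, 6) in the question. The 3 is an illustrative assumption. vdx is replaced by torch.arange(5), which is what the plate yields when it does not subsample. For this particular expression, ordinary advanced indexing should produce the same shape as Vindex because the two index tensors come first and are adjacent. Vindex is the general tool.

```python
import torch

torch.manual_seed(0)

# Assumed sizes: feature_plate = 5 (dim=-2), data_plate = 6 (dim=-1),
# c has 4 components, x has 3 outcome categories (3 is illustrative).
p = torch.rand(5, 4, 3)
p = p / p.sum(-1, keepdim=True)      # each p[i, k] is a distribution over x

vdx = torch.arange(5)                # stand-in for `with feature_plate as vdx`
c = torch.randint(0, 4, (6,))        # one sampled component per data point

# vdx[..., None] has shape (5, 1) and c has shape (6,). Aligned from the
# right, they broadcast to (5, 6): features on dim=-2, data on dim=-1.
pc = p[vdx[..., None], c, :]
print(pc.shape)                      # torch.Size([5, 6, 3])

# Entry [i, j] is feature i's distribution under data point j's component.
assert torch.equal(pc[2, 4], p[2, c[4]])

# The last dim of pc becomes Categorical's category dim, so the batch
# shape lines up with (feature_plate, data_plate).
batch_shape = torch.distributions.Categorical(probs=pc).batch_shape
print(batch_shape)                   # torch.Size([5, 6])

# Without the new axis, index shapes (5,) and (6,) cannot be broadcast.
try:
    p[vdx, c, :]
    without_axis_failed = False
except (IndexError, RuntimeError):
    without_axis_failed = True
print(without_axis_failed)           # True
```

The failing case shows what goes wrong without the trailing None. Both index tensors would then try to occupy dim=-1, and sizes 5 and 6 cannot be broadcast together.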

Thanks for the reply!
I read NumPy’s vindex documentation and found that the new axis is used by vindex for broadcasting.
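The broadcasting rule is the standard NumPy/PyTorch one. Shapes are compared starting from the rightmost dimension, and a size-1 dimension stretches to match the other shape. The small sketch below uses torch.broadcast_shapes (available in PyTorch 1.8 and later) to show how the trailing None places the feature index on dim=-2. The enumerated case assumes that, with two plates, c's 4 enumerated values sit at dim=-3, giving shape (4, 1, 1). Treat that shape as an illustration rather than a guarantee of Pyro's behavior.

```python
import torch

feature_idx = torch.arange(5)[..., None]           # (5, 1): feature index on dim=-2
c_sampled = torch.zeros(6, dtype=torch.long)       # (6,): data index on dim=-1
# Assumption: under parallel enumeration, c's 4 values occupy a new dim to
# the left of both plates, giving shape (4, 1, 1).
c_enumerated = torch.arange(4).reshape(4, 1, 1)

# Shapes are aligned from the right; size-1 dims stretch to match.
sampled_shape = torch.broadcast_shapes(feature_idx.shape, c_sampled.shape)
enumerated_shape = torch.broadcast_shapes(feature_idx.shape, c_enumerated.shape)
print(sampled_shape)      # torch.Size([5, 6])
print(enumerated_shape)   # torch.Size([4, 5, 1])

# The same indexing expression therefore also works when c is enumerated.
p = torch.rand(5, 4, 3)
pc_enum = p[feature_idx, c_enumerated, :]
print(pc_enum.shape)      # torch.Size([4, 5, 1, 3])
```

In both cases the feature index stays on dim=-2. The data index either fills dim=-1 or is size 1 there and broadcasts against data_plate, which is the layout the plates expect.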