poutine.mask(mask=m): does the shape of the mask m have to be the same as the shape of the observation data?


I would expect a mask that is merely broadcastable against the observation data to be OK. Is that right?

I could not find this described in the documentation.

mask(fn=None, mask=None)

Given a stochastic function with some batched sample statements and masking tensor, mask out some of the sample statements elementwise.

Parameters:

  • fn – a stochastic function (callable containing Pyro primitive calls)
  • mask (torch.ByteTensor) – a {0,1}-valued masking tensor (1 includes a site, 0 excludes a site)

Returns: stochastic function decorated with a MaskMessenger


The mask argument to poutine.mask must be broadcastable with the batch_shape of all sample statements it affects. In particular:

  1. The mask arg is unaware of sample statements' .event_shape. E.g. in a MultivariateNormal distribution, each sample vector of shape event_shape corresponds to a single mask element.
  2. The two tensors mask and log_prob (= site["log_prob"] = site["fn"].log_prob(site["value"])) should be broadcastable, i.e. torch.broadcast_tensors(mask, log_prob) should not fail. This allows mask and log_prob to have different shapes, but the shapes must be compatible.
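
To make the two points above concrete, here is a minimal sketch using only torch (no Pyro). The shapes (batch_shape (3, 4), event_shape (2,)) and the boolean mask are assumptions picked for illustration, and the torch.where step only imitates the effect of masking; it is not Pyro's actual implementation.

```python
import torch
from torch.distributions import MultivariateNormal

# Assumed shapes for illustration: batch_shape (3, 4), event_shape (2,).
loc = torch.zeros(3, 4, 2)
dist = MultivariateNormal(loc, covariance_matrix=torch.eye(2))

value = dist.sample()            # shape (3, 4, 2), i.e. batch_shape + event_shape
log_prob = dist.log_prob(value)  # shape (3, 4), the event dimension is reduced away

# A mask of shape (4,) is broadcastable with log_prob's shape (3, 4).
mask = torch.tensor([True, False, True, True])
mask_b, log_prob_b = torch.broadcast_tensors(mask, log_prob)

# Imitate masking: excluded entries contribute zero log-probability.
masked_log_prob = torch.where(mask_b, log_prob_b, torch.zeros_like(log_prob_b))

# A mask shaped like the data itself, (3, 4, 2), does not broadcast
# against log_prob's shape (3, 4), because the trailing dims 2 and 4 clash.
data_shaped_mask = torch.ones(3, 4, 2, dtype=torch.bool)
try:
    torch.broadcast_tensors(data_shaped_mask, log_prob)
    data_shaped_mask_ok = True
except RuntimeError:
    data_shaped_mask_ok = False
```

So for a distribution with a non-empty event_shape, the mask should match (or broadcast to) the batch dimensions of the data, not the full data shape. For a univariate distribution such as Normal, event_shape is empty, so batch_shape and the data shape coincide and a data-shaped mask works.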

Feel free to submit a tiny PR clarifying the docs to your satisfaction :smile:


Thank you so much for the nice explanation! I can understand how the mask works much better now. :blush: