poutine.mask(mask=m): does the shape of the mask m here have to be the same as that of the observation data?

I think a mask that is merely broadcastable with the observation data should be OK?

I could not find a description of this in the documentation.

mask(fn=None, mask=None)

Given a stochastic function with some batched sample statements and masking tensor, mask out some of the sample statements elementwise.

Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • mask (torch.ByteTensor) – a {0,1}-valued masking tensor (1 includes a site, 0 excludes a site)

Returns: stochastic function decorated with a MaskMessenger

The mask argument to poutine.mask must be broadcastable with the batch_shape of all sample statements it affects. In particular:

  1. The mask arg is unaware of sample statements’ .event_shape. E.g. for a MultivariateNormal distribution, each sample vector of shape event_shape corresponds to a single mask element.
  2. The two tensors mask and log_prob (= site["log_prob"] = site["fn"].log_prob(site["value"])) should be broadcastable, i.e. torch.broadcast_tensors(mask, log_prob) should not fail. This allows mask and log_prob to have different shapes, but the shapes must be compatible (see the sketch below).
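Here is a minimal sketch of the second point, with made-up names and shapes: a plate of size N over MultivariateNormal observations with event_shape (D,), and an (N,)-shaped mask (recent Pyro versions accept a bool mask):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(data, mask):
    # data has shape (N, D); each row is one MVN sample, so event_shape = (D,).
    N, D = data.shape
    loc = torch.zeros(D)
    scale_tril = torch.eye(D)
    with pyro.plate("data", N):
        # mask has shape (N,), matching the site's batch_shape (N,),
        # so torch.broadcast_tensors(mask, log_prob) succeeds.
        with poutine.mask(mask=mask):
            pyro.sample("obs",
                        dist.MultivariateNormal(loc, scale_tril=scale_tril),
                        obs=data)

data = torch.randn(5, 3)
mask = torch.tensor([True, True, False, True, False])  # one mask element per row
poutine.trace(model).get_trace(data, mask).log_prob_sum()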

Feel free to submit a tiny PR clarifying the docs to your satisfaction :smile:


Thank you so much for the nice explanation! I can understand how the mask works much better now. :blush:

Hi all!

What would be the correct way of masking a MultivariateNormal distribution within a plate then?
e.g:

  • we have correlated observations Y = cat(y1, y2, y3), but only partially overlapping, with a mask obtained from torch.isnan(Y).
  • So we get a linear regression for the means and an LKJ prior for the scales, wrap it in a pyro plate over the data length and do a sample("obs", MultivariateNormal(…), obs=Y), but any poutine.mask statement gives a broadcast error (roughly the setup sketched below)…
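Roughly, the setup looks something like this (names and shapes are made up for illustration, and it assumes a recent Pyro with dist.LKJCholesky):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(X, Y):
    # Y: (N, D) observations, with NaNs where a measurement is missing.
    N, D = Y.shape
    P = X.shape[-1]
    beta = pyro.sample("beta", dist.Normal(0., 1.).expand([P, D]).to_event(2))
    loc = X @ beta                              # (N, D) regression means
    sd = pyro.sample("sd", dist.HalfCauchy(1.).expand([D]).to_event(1))
    corr_chol = pyro.sample("corr_chol", dist.LKJCholesky(D, concentration=1.0))
    scale_tril = sd.unsqueeze(-1) * corr_chol   # per-dim scales times correlation Cholesky
    mask = ~torch.isnan(Y)                      # (N, D) elementwise mask of observed entries
    with pyro.plate("data", N):
        # The site's batch_shape is (N,) and event_shape is (D,), so log_prob has
        # shape (N,) and the (N, D) mask fails to broadcast against it when the
        # masked log_prob is computed.
        with poutine.mask(mask=mask):
            pyro.sample("obs",
                        dist.MultivariateNormal(loc, scale_tril=scale_tril),
                        obs=torch.nan_to_num(Y))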

Sorry for the necro, but this is the closest I got to what I’m looking for.

Hi @shoubidouwah I’m not sure how to make mathematical sense of masking only part of a joint distribution, since the log density depends on all parameters. But you might try to split up the single MultivariateNormal over Y into three different multivariate normals whose means are linearly dependent:

import torch
import pyro
from pyro.distributions import MultivariateNormal

# Decompose p(y1, y2, y3) into p(y1) p(y2 | y1) p(y3 | y1, y2);
# the loc* and scale_* parameters are assumed to be defined upstream.
y1 = pyro.sample("y1", MultivariateNormal(
    loc1,
    scale_tril=scale_11,
))
y2 = pyro.sample("y2", MultivariateNormal(
    loc2 + y1 @ scale_12,  # mean of y2 depends linearly on y1
    scale_tril=scale_22,
))
y3 = pyro.sample("y3", MultivariateNormal(
    loc3 + y1 @ scale_13 + y2 @ scale_23,  # mean of y3 depends linearly on y1 and y2
    scale_tril=scale_33,
))
y = pyro.deterministic("y", torch.cat([y1, y2, y3], dim=-1))

Then you can individually mask each of the components.
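For example, the masking could look something like this, continuing the sketch above; mask1, mask2, mask3 are hypothetical (N,)-shaped row masks saying which rows actually have that block of measurements, and Y1_obs etc. are the corresponding (fill-valued) observation blocks:

from pyro import poutine

with pyro.plate("data", N):
    with poutine.mask(mask=mask1):   # rows where block 1 was actually measured
        y1 = pyro.sample("y1", MultivariateNormal(
            loc1, scale_tril=scale_11), obs=Y1_obs)
    with poutine.mask(mask=mask2):   # rows where block 2 was actually measured
        y2 = pyro.sample("y2", MultivariateNormal(
            loc2 + y1 @ scale_12, scale_tril=scale_22), obs=Y2_obs)
    with poutine.mask(mask=mask3):   # rows where block 3 was actually measured
        y3 = pyro.sample("y3", MultivariateNormal(
            loc3 + y1 @ scale_13 + y2 @ scale_23, scale_tril=scale_33), obs=Y3_obs)
    y = pyro.deterministic("y", torch.cat([y1, y2, y3], dim=-1))

Each mask here has one entry per row of the plate, i.e. one entry per event, which is exactly the broadcastability rule discussed earlier in the thread.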

Hi @fritzo ,
Thanks heaps for this quick answer.

I know the whole concept is weird. The data is a set of molecules with associated biochemical measurements that are known to be correlated (think solubility and lipophilicity). Not all molecules have all measurements, and there is a large difference between the number of observations for each endpoint (e.g. solubility is measured for 10k molecules, lipophilicity for 100k), which makes imputation tricky.

It does make sense to do as you say, I think. The point is to leverage a learned correlation on what overlapping data we have to enhance prediction.

I’ll try your solution and report the results.

Cheers again!