Given a stochastic function with some batched sample statements and masking tensor, mask out some of the sample statements elementwise.
Parameters:
* fn – a stochastic function (callable containing Pyro primitive calls)
* mask (torch.ByteTensor) – a {0,1}-valued masking tensor (1 includes a site, 0 excludes a site)

Returns: stochastic function decorated with a MaskMessenger
The mask argument to poutine.mask must be broadcastable with the batch_shape of all sample statements it affects. In particular:
The mask arg is unaware of sample statements’ .event_shape. E.g. in a MultivariateNormal distribution, each sample vector of shape event_shape corresponds to a single mask element.
The two tensors mask and log_prob (= site["log_prob"] = site["fn"].log_prob(site["value"])) should be broadcastable, i.e. torch.broadcast_tensors(mask, log_prob) should not fail. This allows mask and log_prob to have different shapes, but the shapes must be compatible.
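To make the broadcasting point concrete, here is a small pure-PyTorch sketch (the shapes are assumed for illustration: a batch of 10 data points and event_shape (3,), as for a 3-dimensional MultivariateNormal). log_prob has one entry per batch element because event dimensions are summed out:

```python
import torch

# One log-density per batch element; event dims are already summed out.
log_prob = torch.zeros(10)

# One mask entry per batch element: broadcastable with log_prob.
mask_ok = torch.ones(10, dtype=torch.bool)
print(torch.broadcast_tensors(mask_ok, log_prob)[0].shape)  # torch.Size([10])

# One mask entry per event coordinate: NOT broadcastable with log_prob,
# since trailing dims 3 and 10 neither match nor equal 1.
mask_bad = torch.ones(10, 3, dtype=torch.bool)
try:
    torch.broadcast_tensors(mask_bad, log_prob)
except RuntimeError:
    print("broadcast error")
```

The second case is exactly what happens if a per-coordinate mask (e.g. from torch.isnan on the data) is applied to a distribution whose event dimension has absorbed those coordinates.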
Feel free to submit a tiny PR clarifying the docs to your satisfaction.
What would be the correct way of masking a MultivariateNormal distribution within a plate then?
e.g:
we have correlated observations Y = cat(y1, y2, y3), but only partially overlapping, with a mask obtained by torch.isnan(Y).
So we get a linear regression for the means and an LKJ prior for the scales, wrap it in a pyro.plate over the data length, and do a sample("obs", MultivariateNormal(…), obs=Y), but any poutine.mask statement gives a broadcast error…
Sorry for the necro, but this is the closest I got to what I’m looking for.
Hi @shoubidouwah I’m not sure how to make mathematical sense of masking only part of a joint distribution, since the log density depends on all parameters. But you might try to split up the single MultivariateNormal over Y into three different multivariate normals whose means are linearly dependent:
I know the whole concept is weird. The data is a set of molecules with associated biochemical measurements that are known to be correlated (think solubility and lipophilicity). Not all molecules have all measurements, and there is a large difference between the number of observations for each endpoint (e.g. solubility is measured for 10k molecules, lipophilicity for 100k), which makes imputation tricky.
It does make sense to do as you say, I think. The point is to leverage a learned correlation on what overlapping data we have to enhance prediction.