Is there a way to marginalize out a continuous latent variable in the model (assuming I cannot or do not want to do this analytically)? I'm assuming the `infer={"enumerate": "parallel"}` option only works for discrete latents that can be fully enumerated.
What if I wanted to approximately marginalize over a continuous latent variable that appears in the model by taking a few samples (and what would need to happen with the ELBO)? I would like to do this so that I can leave that sample statement out of the guide. Does Pyro have some way to do this?
Hi, you can pass a `"num_samples"` argument along with `"enumerate": "parallel"` in `infer` to tell `TraceEnum_ELBO` to approximately marginalize a variable with importance samples from the prior:
```python
# suppose this site appears only in the model
pyro.sample("varname", dist, infer={
    "enumerate": "parallel", "expand": True, "num_samples": 10})
```
Caveats: first, this will only produce correct gradient estimates if `dist` is fully reparametrized, i.e. has an `rsample` method, though this is true of most common continuous distributions such as `Normal`. Second, this will cause every descendant of `varname` (not just every child, as in exact enumeration with `"expand": False`) to have an extra dimension of size `num_samples` in its `log_prob` and sample tensors. If `varname` doesn't have many descendants, this probably isn't a big deal. Finally, the same restrictions on model structure and model-guide alignment that apply to exact enumeration apply here (outlined in the enumeration tutorial).
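To see what such a site computes and why the reparametrization caveat matters, here is a plain-Python sketch (no Pyro) for a hypothetical toy model z ~ Normal(loc, 1), x | z ~ Normal(z, 1), where the exact answer log p(x) = log Normal(x; loc, sqrt(2)) is known. Writing each prior sample as z_k = loc + eps_k makes the K-sample estimate log (1/K) sum_k p(x | z_k) a differentiable function of `loc`, and its pathwise gradient approaches the exact value (x - loc) / 2 as K grows. The function names and the toy model are my own; this is not how Pyro implements it internally.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def is_estimate_and_grad(x, loc, num_samples, rng):
    """K-sample prior importance estimate of log p(x) for the toy model
    z ~ N(loc, 1), x | z ~ N(z, 1), plus its reparametrized gradient w.r.t. loc."""
    eps = [rng.gauss(0.0, 1.0) for _ in range(num_samples)]
    zs = [loc + e for e in eps]  # reparametrized samples: z_k = loc + eps_k
    logw = [normal_logpdf(x, z, 1.0) for z in zs]
    m = max(logw)  # log-sum-exp shift for numerical stability
    ws = [math.exp(lw - m) for lw in logw]
    total = sum(ws)
    estimate = m + math.log(total / num_samples)  # log of mean_k p(x | z_k)
    # d/dloc log p(x | z_k) = (x - z_k) because dz_k/dloc = 1
    grad = sum(w * (x - z) for w, z in zip(ws, zs)) / total
    return estimate, grad


rng = random.Random(0)
x, loc = 1.0, 0.0
est, grad = is_estimate_and_grad(x, loc, 100_000, rng)
exact = normal_logpdf(x, loc, math.sqrt(2.0))
print(est, exact)
print(grad, (x - loc) / 2)
```

If the sampler had no `rsample`, the values z_k would carry no gradient path back to `loc`, so the pathwise term above would simply be missing and a score-function correction would be needed instead.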
Note that you can also take multiple local samples in the guide, in the same way, for a variable appearing in both the model and the guide, in order to reduce the variance of the ELBO estimator.
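As a rough illustration of that variance reduction, here is a plain-Python sketch, assuming the K guide samples are simply averaged into the usual ELBO estimate (my reading of the sentence above; Pyro's exact weighting is not reproduced here). The toy model z ~ N(0, 1), x | z ~ N(z, 1), the guide parameters and the function names are hypothetical. Averaging keeps the expectation of the estimate unchanged and shrinks its variance by roughly a factor of K.

```python
import math
import random
import statistics


def normal_logpdf(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def elbo_estimate(x, m, s, num_samples, rng):
    """Average of num_samples single-sample ELBO terms for the toy model
    z ~ N(0, 1), x | z ~ N(z, 1) with a hypothetical guide q(z) = N(m, s)."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(m, s)  # sample from the guide
        total += (normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)
                  - normal_logpdf(z, m, s))
    return total / num_samples


rng = random.Random(0)
x, m, s = 1.0, 0.3, 0.8  # arbitrary guide parameters
runs_1 = [elbo_estimate(x, m, s, 1, rng) for _ in range(5000)]
runs_10 = [elbo_estimate(x, m, s, 10, rng) for _ in range(5000)]
print(statistics.mean(runs_1), statistics.mean(runs_10))          # about the same
print(statistics.variance(runs_1), statistics.variance(runs_10))  # about 10x smaller
```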
If you have many such variables (for example, a latent Markov chain), you can try setting `"expand": False`, which will use an experimental Tensor Monte Carlo estimator that avoids propagating sample dimensions beyond each variable's children at the cost of extra variance in the overall ELBO estimate. If you want to work with non-reparametrized distributions or experiment with a different IWAE-style ELBO, you could try the Tensor Monte Carlo lower bound `pyro.infer.TraceTMC_ELBO`.
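For intuition about the Tensor Monte Carlo idea, here is a plain-Python sketch for a hypothetical chain z1 -> z2 -> x of Normals. Each latent keeps its own K samples, all K x K pairings are averaged, and z1's sample index is summed out inside the factor p(z2 | z1) before moving on, so it never reaches x. The toy model, the function names and the choice of proposals (the prior for z1, the prior marginal N(0, sqrt(2)) for z2) are my assumptions; this illustrates the general estimator, not Pyro's `TraceTMC_ELBO` code.

```python
import math
import random


def normal_pdf(x, loc, scale):
    return math.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * math.sqrt(2 * math.pi))


def tmc_estimate(x, num_samples, rng):
    """Unbiased TMC-style estimate of p(x) for the toy chain
    z1 ~ N(0, 1), z2 | z1 ~ N(z1, 1), x | z2 ~ N(z2, 1)."""
    K = num_samples
    q2_scale = math.sqrt(2.0)  # assumed proposal for z2: its prior marginal N(0, sqrt(2))
    # Proposal for z1 is its prior, so the weight p(z1)/q1(z1) is exactly 1 and omitted.
    z1s = [rng.gauss(0.0, 1.0) for _ in range(K)]
    z2s = [rng.gauss(0.0, q2_scale) for _ in range(K)]
    total = 0.0
    for z2 in z2s:
        # Sum out z1's sample index first: (1/K) * sum_i p(z2 | z1_i).
        msg = sum(normal_pdf(z2, z1, 1.0) for z1 in z1s) / K
        total += msg * normal_pdf(x, z2, 1.0) / normal_pdf(z2, 0.0, q2_scale)
    return total / K


rng = random.Random(0)
x = 0.5
est = tmc_estimate(x, 400, rng)
exact = normal_pdf(x, 0.0, math.sqrt(3.0))  # marginally x ~ N(0, sqrt(3))
print(math.log(est), math.log(exact))
```

The cost is K^2 pdf evaluations per link of the chain rather than growing with the number of sample combinations across the whole chain. Taking the log of this unbiased estimate gives a stochastic lower bound on log p(x) by Jensen's inequality.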