How can I figure out how the ELBO-related modules work?

I am wondering whether there is a good method (e.g. a debugging approach) to see exactly how the ELBO-related modules compute the loss. At the moment they behave like black boxes to me, and I don't know how the various ELBO modules calculate the loss (e.g. given a single data sample). Thanks!

Hi @Mucan, if you’re asking for background on the math implemented by Pyro’s ELBO variants, there are many references in the docstrings, e.g. for Trace_ELBO, and in the SVI tutorials, e.g. SVI Part III: ELBO Gradient Estimators. If you want to understand TraceEnum_ELBO in particular, you might also see our paper Tensor Variable Elimination for Plated Factor Graphs.
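For orientation, Trace_ELBO estimates the standard evidence lower bound and returns its negation as the loss. The notation below is ours, not taken from Pyro's docs: $p$ is the joint density defined by the model, $q$ is the guide, $z$ stands for the latent sample sites, $x$ for the observed data, and $N$ for `num_particles`.

$$
\mathrm{ELBO}(x) = \mathbb{E}_{z \sim q(z)}\big[\log p(x, z) - \log q(z)\big] \le \log p(x),
\qquad
\widehat{\mathrm{loss}} = -\frac{1}{N}\sum_{n=1}^{N}\big[\log p(x, z_n) - \log q(z_n)\big], \quad z_n \sim q(z).
$$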

If you’re asking about the whole process of going from Pyro model to ELBO values, the MiniPyro tutorial is a good summary of the way Pyro works under the hood. The keyword arguments to ELBO implementation constructors are mostly shared and are documented in the base pyro.infer.ELBO class. The three main methods loss, differentiable_loss and loss_and_grads are summarized in the Trace_ELBO documentation.

If you’re asking about specific ELBO implementations, the best thing to do is start with the above suggestions and then just look at the source code. For example, Trace_ELBO isn’t that much more complicated than the MiniPyro example’s ELBO.

Thank you very much!