How to get log p(z,x) (joint) and log q(z) when using autoguide.AutoDiagonalNormal()

How can we get log p(z, x) (the joint) and log q(z) when using autoguide.AutoDiagonalNormal()? We could not find where these are exposed.

We would like to compute the ELBO manually. How does Pyro compute
ELBO = -svi.step(x, y)
where svi = pyro.infer.SVI(model, guide, optimizer, loss=JitTraceGraph_ELBO())?

Could you please help? Thank you!

Hi @pookie, here is the source code for the TraceGraph ELBO (it is fairly complicated) if you want to take a look. This example shows how to construct a custom ELBO, including how to get log p(z, x) and log q(z). For example, to compute log q(z), you can use pyro.poutine.trace(guide).get_trace(*args, **kwargs).log_prob_sum(), where args and kwargs are the arguments of your model (the guide takes the same arguments).


Thank you very much @fehiepsi. This is very helpful! :smile:
It looks like the z samples are not the same when we trace the model (to get log p(z, data)) and then trace the guide (to get log q(z) or q(z|data)). We could not reproduce the exact value returned by svi.step(data), even when setting the Pyro random seed. For simple models, however, the values are close.

Edit: we used poutine.replay to run the model with the guide's z samples, and that works; we got much better ELBO values. Thank you again, and Happy New Year! :tada: