Variational inference with non-bijective stochastic functions

I am trying to perform inference with a specific model in Pyro, which leads me to a more general question: can Pyro perform inference in models where the log-pdf is not easily calculable?

I am trying to modify the SVI Part I tutorial to model a problem where a latent variable is the max of two other latent variables. It seems to me that this violates the “we can compute the pointwise log pdf p_i” requirement that the tutorials state for variational inference.

  1. Would I need to find a bijective approximation for the max so that the log-likelihood can be computed?
  2. For general stochastic functions, i.e., with a compositional structure p_i(x_i | f(z_i)) where the functions f are complicated and perhaps non-invertible, how does Pyro perform inference? Or are such functions disallowed?
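A minimal sketch of why a max need not be inverted, written in plain Python rather than Pyro. It assumes (the question does not say) that the max enters the model as the mean of a Normal observation, x ~ Normal(max(z1, z2), obs_scale), with standard Normal priors by default; `normal_logpdf` and `log_joint_max` are hypothetical helper names. Only z1, z2 and x get density terms; max(z1, z2) is just a deterministic value computed along the way.

```python
import math


def normal_logpdf(value, loc, scale):
    """Log density of Normal(loc, scale) evaluated at value."""
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def log_joint_max(z1, z2, x, mu1=0.0, sigma1=1.0, mu2=0.0, sigma2=1.0, obs_scale=1.0):
    """log p(x, z1, z2) for the assumed model x ~ Normal(max(z1, z2), obs_scale)."""
    # z1 and z2 are sampled latent variables: their prior densities are evaluated directly.
    log_p = normal_logpdf(z1, mu1, sigma1) + normal_logpdf(z2, mu2, sigma2)
    # m is a deterministic function of the latents; it adds no density term of its own,
    # so the non-invertibility of max never matters. It only sets the likelihood's mean.
    m = max(z1, z2)
    log_p += normal_logpdf(x, m, obs_scale)
    return log_p
```

The same pattern covers any complicated or non-invertible f in p_i(x_i | f(z_i)): f is evaluated forward, never inverted.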

In particular, consider the following model:
z1 ~ Normal(mu1, sigma1)
z2 ~ Normal(mu2, sigma2)
x ~ Normal(z1^2 + z2^2, sigma=1)

And my inference task is to estimate E[z1 | x = 5]. How does Pyro compute the pdf p(x, z1, z2) to perform variational inference?

Any help or pointers would be greatly appreciated.

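Hi @innuo, I assume that x is your observation. First, since z1 and z2 are independent, p(z1, z2) = p(z1) · p(z2), and each factor is just a Normal density. Then, because x ~ Normal(z1^2 + z2^2, sigma=1), evaluating that Normal's log density at the observed x gives log p(x | z1, z2). The joint then follows from p(x, z1, z2) = p(x | z1, z2) · p(z1, z2). Note that z1^2 + z2^2 never has to be inverted: it only sets the mean of the likelihood, and densities are evaluated only at the sample sites z1, z2 and x. The same holds for max(z1, z2) or any other deterministic function.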
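To illustrate the steps above, here is a minimal plain-Python sketch (not Pyro's implementation) of the joint log density, plus a Monte Carlo estimate of the ELBO, E_q[log p(x, z1, z2) - log q(z1, z2)], for a mean-field Normal guide q. The prior parameters default to mu = 0, sigma = 1 as an assumption, and the guide parameters `q_loc1`, `q_scale1`, `q_loc2`, `q_scale2` are hypothetical names. Pyro's `Trace_ELBO` builds an estimate of this kind from the log probabilities it records at each `pyro.sample` site, and differentiates it with PyTorch.

```python
import math
import random

LOG_2PI = math.log(2 * math.pi)


def normal_logpdf(value, loc, scale):
    """Log density of Normal(loc, scale) evaluated at value."""
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * LOG_2PI


def log_joint(x, z1, z2, mu1=0.0, sigma1=1.0, mu2=0.0, sigma2=1.0):
    """log p(x, z1, z2) = log p(z1) + log p(z2) + log p(x | z1, z2)."""
    # Independent Normal priors on the two latents.
    log_prior = normal_logpdf(z1, mu1, sigma1) + normal_logpdf(z2, mu2, sigma2)
    # Likelihood: x ~ Normal(z1^2 + z2^2, 1); the squaring is only evaluated forward.
    log_lik = normal_logpdf(x, z1 ** 2 + z2 ** 2, 1.0)
    return log_prior + log_lik


def elbo_estimate(x, q_loc1, q_scale1, q_loc2, q_scale2, num_samples=1000, seed=0):
    """Monte Carlo estimate of E_q[log p(x, z1, z2) - log q(z1, z2)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        # Draw the latents from the guide q, not from the prior.
        z1 = rng.gauss(q_loc1, q_scale1)
        z2 = rng.gauss(q_loc2, q_scale2)
        log_q = normal_logpdf(z1, q_loc1, q_scale1) + normal_logpdf(z2, q_loc2, q_scale2)
        total += log_joint(x, z1, z2) - log_q
    return total / num_samples
```

For x = 5, a guide centered near a point with z1^2 + z2^2 ≈ 5 (say both locations at 1.5) should score a clearly higher ELBO than one centered at the origin; SVI adjusts the guide parameters to push this quantity up.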

To build a Pyro model for your task, you can follow the tutorial “An Introduction to Inference in Pyro” (marked deprecated in the Pyro Tutorials 1.8.4 documentation).
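As a Pyro-independent cross-check of E[z1 | x = 5], here is a minimal sketch using self-normalized importance sampling with the prior as the proposal. This is a different technique from VI, swapped in only because it needs exactly the same ingredient: pointwise evaluation of the model's densities. The default priors (mu = 0, sigma = 1) and the function name `posterior_mean_z1` are assumptions for illustration. With mu1 = 0 the model is symmetric under z1 → -z1, so the posterior mean of z1 is 0; a nonzero mu1 breaks that symmetry.

```python
import math
import random


def normal_logpdf(value, loc, scale):
    """Log density of Normal(loc, scale) evaluated at value."""
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def posterior_mean_z1(x, num_samples=20000, mu1=0.0, sigma1=1.0, mu2=0.0, sigma2=1.0, seed=0):
    """Self-normalized importance sampling estimate of E[z1 | x] (hypothetical helper)."""
    rng = random.Random(seed)
    log_weights = []
    z1_samples = []
    for _ in range(num_samples):
        # Propose from the prior, so each importance weight is just p(x | z1, z2).
        z1 = rng.gauss(mu1, sigma1)
        z2 = rng.gauss(mu2, sigma2)
        log_weights.append(normal_logpdf(x, z1 ** 2 + z2 ** 2, 1.0))
        z1_samples.append(z1)
    # Subtract the max log weight before exponentiating to avoid underflow.
    max_lw = max(log_weights)
    weights = [math.exp(lw - max_lw) for lw in log_weights]
    return sum(w * z for w, z in zip(weights, z1_samples)) / sum(weights)
```

For example, `posterior_mean_z1(5.0)` should land close to 0, while `posterior_mean_z1(5.0, mu1=1.0)` is pulled to a positive value.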

For those with similar questions, this was discussed at length in the Pyro GitHub issues at https://github.com/uber/pyro/issues/773.