How to get the joint distribution in Pyro?

I am reading the tutorial on SVI and trying to figure out how to get the joint distribution log p_θ(x, z) from the data. What troubles me is that all we have is the data x, not the latent variable z. Then how is it possible at all to construct the log p_θ(x, z) needed in ELBO maximization?

Also, what does the model construct in Pyro correspond to? The tutorial said clearly that the guide corresponds to q_ϕ(z). I think I am lost on the model.

The latent variable z is sampled from the guide. The third SVI tutorial explains the computation in detail.

In Pyro, the model corresponds to p(x,z;θ) and the guide corresponds to q(z|x;ϕ).

What troubles me is that all we have is the data x, not the latent variable z.

You’re correct that you need z to compute log p(x,z;θ). This is exactly what the guide is for: to produce a z given an x value.

Thanks for the clarifications! But why does model correspond to p(x,z;θ)? Does Pyro do something like producing a histogram from the data x and the latent variable z generated from q(z|x;ϕ) and then using the histogram to approximate p(x,z;θ)? How is z related (linked) to x? After all, to get the ELBO (the loss for optimization), we need approximations for both p(x,z;θ) and q(z|x;ϕ).

Pyro calculates p(x, z) = p(x|z)p(z), where z is a sample from q(z).

Hi fritzo, on second thought, the example in the tutorial suggests that the model construct actually corresponds to the posterior probability p_θ(z|x), because of the conditional “obs=” in the following line:
pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

If I missed something, could you point out a link that shows the model construct corresponds to the joint probability? Thanks!

In that example, obs is x. Once you have a sample f (which is z) from the Beta distribution, you can compute p(obs | f). Multiplying this by p(f) (the density of f under the Beta distribution) gives you the joint probability. The model gives you a sequence of probabilities p(f), p(obs | f). Under the hood, to compute the joint probability, Pyro multiplies the terms in that sequence together (or sums them in log space). This is a common approach across probabilistic programming languages.

In code, you can use pyro.poutine.trace(model).get_trace(data).log_prob_sum() to get the joint log probability of a model.
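To make the "sum in log space" concrete, here is a minimal hand-computed sketch that uses only the Python standard library, not Pyro. It assumes the tutorial's coin model with a Beta(10, 10) prior on f and Bernoulli(f) observations; the helper names (beta_log_pdf, bernoulli_log_pmf, joint_log_prob) and the data list are made up for illustration.

```python
import math

def beta_log_pdf(f, a, b):
    # log Beta(f; a, b) = (a-1) log f + (b-1) log(1-f) - log B(a, b)
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1) * math.log(f) + (b - 1) * math.log(1 - f) - log_norm

def bernoulli_log_pmf(x, f):
    # log p(x | f) for a single 0/1 observation
    return math.log(f) if x == 1 else math.log(1 - f)

def joint_log_prob(f, data, a=10.0, b=10.0):
    # log p(data, f) = log p(f) + sum_i log p(data[i] | f)
    log_prior = beta_log_pdf(f, a, b)
    log_lik = sum(bernoulli_log_pmf(x, f) for x in data)
    return log_prior + log_lik

data = [1, 1, 1, 0, 1]  # illustrative coin flips (assumption)
print(joint_log_prob(0.6, data))
```

Note that f has to be supplied from somewhere: the data alone does not determine it, which is why a value of the latent variable is needed before the joint can be evaluated.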

If you use a guide, f is sampled from q(f) instead of p(f).
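As a rough illustration of how a guide sample enters the ELBO, the sketch below draws f from a Beta guide and averages log p(x, f) - log q(f). It is a standard-library approximation of the idea, not Pyro's implementation; the Beta(10, 10) prior, the guide parameters a_q and b_q, the data, and all function names are assumptions chosen for the example.

```python
import math
import random

def beta_log_pdf(f, a, b):
    # log density of Beta(a, b) at f
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1) * math.log(f) + (b - 1) * math.log(1 - f) - log_norm

def log_joint(f, data, a0=10.0, b0=10.0):
    # log p(data, f) = log p(f) + sum_i log p(data[i] | f)
    log_lik = sum(math.log(f) if x == 1 else math.log(1 - f) for x in data)
    return beta_log_pdf(f, a0, b0) + log_lik

def elbo_estimate(data, a_q, b_q, num_samples=1000, seed=0):
    # Monte Carlo estimate of E_q[log p(x, f) - log q(f)]
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        f = rng.betavariate(a_q, b_q)  # latent sampled from the guide q(f)
        total += log_joint(f, data) - beta_log_pdf(f, a_q, b_q)
    return total / num_samples

data = [1, 1, 1, 0, 1]  # illustrative coin flips (assumption)
print(elbo_estimate(data, a_q=2.0, b_q=8.0))    # a poorly matched guide
print(elbo_estimate(data, a_q=14.0, b_q=11.0))  # the exact posterior for this data
```

When the guide equals the exact posterior (Beta(14, 11) for 4 heads and 1 tail under this prior), every sample gives the same value, the log evidence log p(x); any other guide gives a lower ELBO in expectation.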

Hope that helps.

Hi fehiepsi, it helps a lot; things start to make sense now. But I still have two questions.

Question 1:
What exactly does the pyro.sample with obs= do in the example?
pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

Specifically, does it generate a sample of a Bernoulli-distributed random variable, or does it only calculate the probability p(obs|f)? If the latter, it is totally different from the same statement without obs=data[i].

In the example above, the condition seems to be that the latent variable is equal to f. But in another example, the tutorial says that obs= is the condition. How do I reconcile these contradictory explanations?

Question 2:
In the example, it is assumed that the latent variable f influences the data through the Bernoulli distribution with parameter f. That is nice and makes things simple. In a more realistic application, for example, in image recognition using a neural network, how do we find the distribution?


In the first example, by conditioning on obs=data, we have joint_prob = p(f, obs=data) = p(obs=data | f) p(f). With obs, nothing is sampled at that site: the value returned by the pyro.sample statement is data itself, and the site contributes p(obs=data | f) to the joint. Without obs, the value returned by the pyro.sample statement is a sample from the Bernoulli(f) distribution.

I think that the best way to understand what Pyro really does under the hood is to use pyro.poutine.
You can generate a trace with trace = pyro.poutine.trace(model).get_trace(data), then print out trace.nodes, which is a dict containing the information recorded for every param/sample site in the model.
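For intuition about the difference between an observed and an unobserved site, here is a toy stand-in written with the standard library only. It is not Pyro's implementation: toy_sample, the plain dict used as a trace, and the keys "value", "log_prob" and "is_observed" are hypothetical names chosen for this sketch.

```python
import math
import random

def bernoulli_log_pmf(x, p):
    # log p(x | p) for a single 0/1 outcome
    return math.log(p) if x == 1 else math.log(1 - p)

def toy_sample(trace, name, p, obs=None, rng=None):
    # Toy Bernoulli sample site (NOT Pyro's code).
    # With obs: no random draw, the observed value is returned as-is.
    # Without obs: a fresh 0/1 value is drawn from Bernoulli(p).
    rng = rng or random
    if obs is None:
        value = 1 if rng.random() < p else 0
    else:
        value = obs
    # Either way, the site records the log-probability of its value.
    trace[name] = {
        "value": value,
        "log_prob": bernoulli_log_pmf(value, p),
        "is_observed": obs is not None,
    }
    return value

trace = {}
toy_sample(trace, "obs_0", 0.6, obs=1)  # conditioned on data
toy_sample(trace, "free_0", 0.6)        # actually sampled
for name, node in trace.items():
    print(name, node)
```

In both cases the site ends up with a value and a log-probability; the only difference is where the value comes from.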

About question 2, I guess you mean how to build a model? I don’t have an answer for it. It depends on knowledge, intuition, data, and so on, I guess. A first step may be to learn what the input and the output (the value generated by the sample method) of a distribution are. In that post, the author used a Categorical distribution because that distribution is popular for multi-class classification problems; its input is a value (logits) produced by the neural network and its output is a sampled digit class.