Getting the meaning behind pyro.sample with obs

Hi there!
I’m pretty new to probabilistic learning and trying to understand how VI works.
I’m working through the SVI Part I: An Introduction to Stochastic Variational Inference in Pyro tutorial and am struggling to understand how it works.

As I understand it, during VI we need to calculate the likelihood of the observed data given our parameters. In that tutorial the model is defined as

import torch
import pyro
import pyro.distributions as dist

def model(data):
    # define the hyperparameters that control the beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the bernoulli
        # likelihood Bernoulli(f)
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

So the last comment says that we observe data points using the Bernoulli likelihood.
What I don’t understand is how pyro.sample is used here to calculate the likelihood. Wouldn’t the likelihood in this case be, given the parameter f, how probable it is that the resulting distribution dist.Bernoulli(f) produces our data[i]? I can’t see where that is calculated, because pyro.sample with an obs argument always returns the same value (which makes sense: we are conditioning on a data point, so the "sample" is the data point itself).

Therefore I also don’t understand how VI works. If I understand the idea correctly, during VI we sample parameters from the approximating distribution (defined in the guide function) and compute the ELBO loss. For the first part of this loss we need to run our model with the sampled parameters and calculate the likelihood, right? So where does that happen?
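For reference, the objective being described can be written as follows for this model, with f the latent fairness, x_i the observations, p(f) = Beta(f | 10, 10), and q(f) the guide. The second line is the single-sample estimate using one draw from the guide; the loss that is minimized is the negative ELBO.

```latex
\mathrm{ELBO}(q) = \mathbb{E}_{q(f)}\Big[\log p(f) + \sum_{i} \log p(x_i \mid f) - \log q(f)\Big]

\mathrm{ELBO}(q) \approx \log p(f^{(s)}) + \sum_{i} \log \mathrm{Bernoulli}(x_i \mid f^{(s)}) - \log q(f^{(s)}),
\qquad f^{(s)} \sim q(f)
```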

Thanks for any help!

Good question! The sample, param, … statements are Pyro primitives, which are used to construct a generative model, i.e. a model that generates your data. By themselves they have nothing to do with evaluating probabilities, the ELBO, or a loss, so none of those computations happen when you simply call the model.

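If you want additional information from the model, such as the ELBO, you can use inference utilities like pyro.infer.Trace_ELBO. These run the model under effect handlers that intercept each pyro.sample call and record its distribution and value; for an observed site the value is the data point, so the likelihood term is that distribution’s log-probability of the data. If you want to play with this yourself, you can start with a trace, like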

trace = pyro.poutine.trace(model).get_trace(data)
logp = trace.log_prob_sum()

So does the VI algorithm call these functions under the hood in order to get the likelihood?

Yes. That’s right.