GPs: Variational Inference + Model Definition in VariationalGP

Hi All,

First of all, thank you for this amazing probabilistic programming language; it is of great help in my research! I have a question about how variational inference is carried out in VariationalGP models. In particular, I am not sure I completely understand the use of the sampling statements in the model definition below:

        if self.whiten:
            identity = eye_like(self.X, N)
            pyro.sample("f",
                        dist.MultivariateNormal(zero_loc, scale_tril=identity)
                            .to_event(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample("f",
                        dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                            .to_event(zero_loc.dim() - 1))

            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        f_var = f_scale_tril.pow(2).sum(dim=-1)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)

Why are we not using the samples of the latent variable f in the definition of the likelihood (in the later stages of the model), but instead leave the task of defining the observed variable y entirely to the trainable parameters f_loc and f_scale_tril? Specifically, following a generative view of Gaussian Processes, it would seem reasonable to use a sampled latent variable, e.g. fs = pyro.sample("f",...), in the likelihood of the model N(y | fs, sigma).

Also, I was wondering whether this could be related to the training scheme used in the Gaussian Process introduction in the documentation:

    import torch
    import pyro

    # `gp` is the pyro.contrib.gp model (an nn.Module); `smoke_test` comes from the tutorial setup
    optimizer = torch.optim.Adam(gp.parameters(), lr=0.005)
    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
    losses = []
    num_steps = 2500 if not smoke_test else 2
    for i in range(num_steps):
        optimizer.zero_grad()
        loss = loss_fn(gp.model, gp.guide)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

as opposed to a more “pyro-classical” SVI training procedure such as:

    import pyro.optim
    from pyro.infer import SVI, Trace_ELBO

    optimizer = pyro.optim.Adam({"lr": 0.01})
    svi = SVI(gp.model, gp.guide, optimizer, loss=Trace_ELBO())
    for i in range(1000):
        svi.step()

In the first scheme, if I understand correctly, we are optimizing gp.parameters(), which also include the guide parameters f_loc and f_scale_tril (these are not pyro.param statements in the guide, but trainable torch Parameters used in both the model and the guide). Does this mean we are not explicitly using a guide that we can sample from through, for example, svi.run() + EmpiricalMarginal (which we could do with the second training scheme)?

My intuition is that the second training scheme would somehow require using latent samples (i.e. fs = pyro.sample("f",...)) when defining the observations y, so that the guide can provide a feasible posterior approximation.

I hope I managed to be sufficiently clear, and thank you very much in advance for your help!

Hi @Daniele, both training schemes are equivalent; they differ only in where the parameters are stored. In fact, the first version of the gp module used the second scheme together with pyro.param statements, but it later turned out to be simpler to implement GP models as nn.Modules.

Why are we not using the samples from the latent variable f

In some situations we can compute p(y | f_loc, f_var) exactly (e.g. when the likelihood is Gaussian), so a sample of f is not needed. In other cases, likelihood(f_loc, f_var, self.y) draws a sample f from Normal(f_loc, f_var.sqrt()), which is similar to what you want. There are also other situations:
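For the Gaussian case, f can be integrated out in closed form: ∫ N(y | f, σ²) N(f | f_loc, f_var) df = N(y | f_loc, f_var + σ²). A minimal torch sketch comparing that closed form with a Monte Carlo average over sampled f (the function names are hypothetical and this is not Pyro's implementation):

```python
import math
import torch
from torch.distributions import Normal

def exact_log_marginal(y, f_loc, f_var, noise_var):
    # f ~ N(f_loc, f_var) and y | f ~ N(f, noise_var) integrate to y ~ N(f_loc, f_var + noise_var)
    return Normal(f_loc, (f_var + noise_var).sqrt()).log_prob(y)

def mc_log_marginal(y, f_loc, f_var, noise_var, num_samples=200_000):
    # draw f_s ~ N(f_loc, f_var); shape (num_samples,) + f_loc.shape
    f = Normal(f_loc, f_var.sqrt()).sample((num_samples,))
    log_p = Normal(f, math.sqrt(noise_var)).log_prob(y)
    # log((1/S) * sum_s p(y | f_s)), computed stably with logsumexp
    return torch.logsumexp(log_p, dim=0) - math.log(num_samples)

torch.manual_seed(0)
y = torch.tensor([0.3, -1.2])
f_loc = torch.tensor([0.0, -1.0])
f_var = torch.tensor([0.5, 0.2])
exact = exact_log_marginal(y, f_loc, f_var, 0.1)
estimate = mc_log_marginal(y, f_loc, f_var, 0.1)
```

The Monte Carlo estimate only approaches the closed form with many samples, which is why no sample of f is needed when the closed form exists.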

  • we would like to estimate likelihood(f_loc, f_var, self.y) using many samples f (not just one)
  • we want to use Gauss-Hermite quadrature to approximate likelihood(f_loc, f_var, self.y).

All of these situations call for custom likelihoods rather than changes to the VGP model and guide.
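To illustrate the two options for a Bernoulli likelihood, here is a minimal numpy sketch of the expected log-likelihood E_{N(f | f_loc, f_var)}[log p(y | f)] that such a likelihood contributes to the ELBO, estimated once with many Monte Carlo samples of f and once with Gauss-Hermite quadrature. The function names are hypothetical; only `numpy.polynomial.hermite.hermgauss` is a library call:

```python
import numpy as np
from numpy.polynomial.hermite import hermgauss

def bernoulli_log_lik(y, f):
    # log Bernoulli(y | logits=f) = y*f - log(1 + exp(f)), computed stably
    return y * f - np.logaddexp(0.0, f)

def expected_log_lik_gh(y, f_loc, f_var, degree=20):
    y, f_loc, f_var = map(np.asarray, (y, f_loc, f_var))
    x, w = hermgauss(degree)  # nodes/weights for integrals of the form ∫ exp(-x^2) g(x) dx
    # change of variables f = f_loc + sqrt(2 f_var) x turns E_N[g(f)] into (1/sqrt(pi)) Σ w_i g(f_i)
    f = f_loc[..., None] + np.sqrt(2.0 * f_var)[..., None] * x
    return (w * bernoulli_log_lik(y[..., None], f)).sum(-1) / np.sqrt(np.pi)

def expected_log_lik_mc(y, f_loc, f_var, num_samples=100, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    # num_samples draws of f for every data point, then average the log-likelihoods
    eps = rng.standard_normal((num_samples,) + np.shape(f_loc))
    f = np.asarray(f_loc) + np.sqrt(f_var) * eps
    return bernoulli_log_lik(y, f).mean(0)

y = np.array([1.0, 0.0, 1.0])
f_loc = np.array([0.5, -1.0, 2.0])
f_var = np.array([0.3, 1.0, 0.1])
gh = expected_log_lik_gh(y, f_loc, f_var)
mc = expected_log_lik_mc(y, f_loc, f_var, num_samples=200_000, rng=np.random.default_rng(0))
```

Quadrature needs only about 20 evaluations per data point for a smooth one-dimensional integrand, whereas the Monte Carlo estimate needs far more samples for similar accuracy.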

Note also that the sample('f', ...) statement is used to obtain a Monte Carlo estimate of KL(q(f) || p(f)). If you want to compute the exact KL, use TraceMeanField_ELBO instead of Trace_ELBO when defining loss_fn.
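To make the difference concrete, here is a minimal torch sketch assuming a small RBF prior and a hypothetical q(f) (the helpers `rbf_kernel` and `mc_kl` are illustrative, not Pyro API). One sample of log q(f) - log p(f), which is what a single "f" site in the model and guide contributes under Trace_ELBO, is an unbiased but noisy estimate of the KL; `torch.distributions.kl_divergence` returns the exact value, which TraceMeanField_ELBO can use when an analytic KL is available:

```python
import torch
from torch.distributions import MultivariateNormal, kl_divergence

def rbf_kernel(X, lengthscale=1.0):
    # hypothetical helper: squared-exponential kernel on 1-D inputs
    return torch.exp(-0.5 * (X.unsqueeze(-1) - X.unsqueeze(-2)).pow(2) / lengthscale ** 2)

def mc_kl(q, p, num_samples):
    # Monte Carlo estimate of KL(q || p) = E_q[log q(f) - log p(f)]
    f = q.sample((num_samples,))
    return (q.log_prob(f) - p.log_prob(f)).mean()

X = torch.tensor([0.0, 0.5, 1.0])
Kff = rbf_kernel(X) + 0.1 * torch.eye(3)  # jitter keeps the prior well conditioned
prior = MultivariateNormal(torch.zeros(3), scale_tril=torch.linalg.cholesky(Kff))
# a hypothetical variational posterior q(f) = N(f_loc, f_scale_tril f_scale_tril^T)
q = MultivariateNormal(torch.tensor([0.3, -0.2, 0.1]), scale_tril=0.5 * torch.eye(3))

torch.manual_seed(0)
exact = kl_divergence(q, prior)
single = mc_kl(q, prior, num_samples=1)       # the kind of estimate used per step
many = mc_kl(q, prior, num_samples=200_000)   # converges toward `exact`
```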

Hi @fehiepsi, thanks! I agree that, regardless of the kind of likelihood used, we get the desired behavior:

  • With a Gaussian likelihood and conjugacy, the posterior is Gaussian and available analytically.

  • With a non-Gaussian likelihood, we still approximate the (potentially non-Gaussian) posterior with Normal(f_loc, f_var.sqrt()).

However, I feel I didn’t completely understand the part where you say

(in particular, not why we would want to do that, but how the current setup allows it).

Again, thank you for your help :wink:

I think we can just create a custom likelihood for this purpose. For example, a Binary likelihood that uses 100 samples of f could be implemented as:

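    import pyro
    import pyro.distributions as dist
    from pyro.contrib.gp.likelihoods import Likelihood

    class Binary(Likelihood):
        def forward(self, f_loc, f_var, y=None):
            # 100 reparameterized samples of f, shape (100,) + f_loc.shape
            f = dist.Normal(f_loc, f_var.sqrt()).rsample(sample_shape=(100,))
            y_dist = dist.Bernoulli(logits=f).to_event(f.dim())
            obs = None if y is None else y.expand(f.shape)  # match the (100, N) sample shape
            with pyro.poutine.scale(scale=0.01):  # 1/100 turns the sum over samples into a mean
                return pyro.sample("y", y_dist, obs=obs)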

That is, we draw 100 samples of f, compute the 100 corresponding log-likelihoods log p(y | f), and then take their mean (the sum is scaled by 0.01 using poutine.scale).

Ok I see, thanks!

However, I am wondering: would we be able to run MCMC on the VariationalGP model to obtain posterior samples of f, rather than using SVI to optimize the parameters of the variational distribution (f_loc and f_scale_tril)?

That’s a great question! I believe you would have to create a subclass of VGP with a new method, say mcmodel, in which the sampled f is used directly:

    ...
    f = pyro.sample("f", dist.MultivariateNormal(zero_loc, scale_tril=Lff))
    return pyro.sample("y", dist.Bernoulli(logits=f), obs=y)

The existing model implementation of the VGP/VSGP classes won’t work for this purpose. To make HMC stable, it is better to use the whitened version:

    ...
    whiten_f = pyro.sample("whiten_f", dist.MultivariateNormal(zero_loc, scale_tril=identity))
    f = Lff.matmul(whiten_f.unsqueeze(-1)).squeeze(-1)
    return pyro.sample("y", dist.Bernoulli(logits=f), obs=y)