Multiple sparse, latent, partially observed Gaussian Processes with variational mini-batch inference in the context of a mixed effects model

I’m seeking advice on building a mixed effects model that includes multiple sparse, latent, partially observed Gaussian Processes (GPs). To fit the model, I’d like to use variational inference with mini-batches. Since I’m relatively new to all of this, I’ll spell out why I think I need each of the above characteristics, in case I’m misusing any of the terms.

By “mixed effects” (I could be abusing the term here), I just mean that the GPs will be part of the equation used to predict the observable: the sampled values of the GPs will be combined with the values of other random variables, drawn from other distributions, to predict the actual observable.

By “multiple” I simply mean that I need more than one GP. These could potentially be combined into a single “multitask” GP, but I think it is fine to keep them independent of each other for my present purposes.

By “sparse” I mean that I want to save computational cost by sampling the GPs at a pre-defined set of inducing points. In fact, in my application, observations recur at a discrete set of points that saturates as the dataset grows, so there’s really no need to have a new X for every datapoint.
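Because the inputs recur, one way to exploit this (a sketch, assuming 1-D inputs in a tensor; the values and names are illustrative) is to evaluate the GP only at the distinct inputs and gather per-datapoint values through the inverse index returned by torch.unique:

```python
import torch

# Hypothetical inputs that recur at a small discrete set of locations
x = torch.tensor([0.0, 1.0, 0.0, 2.0, 1.0, 2.0, 2.0])

# x_unique holds the distinct inputs; idx maps each datapoint to its position in x_unique
x_unique, idx = torch.unique(x, return_inverse=True)

# Stand-in for GP values evaluated only at the distinct inputs
f_unique = torch.sin(x_unique)

# One value per datapoint, without re-evaluating the GP
f_per_point = f_unique[idx]
```

If x_unique is also used as the set of inducing inputs, the latent values at the data coincide with the inducing values (up to jitter), so the sparse approximation adds no error beyond the variational approximation itself.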

By “latent” I mean that my observations do not directly measure the value of the GP. Instead, I will use the sampled values in an equation to predict the observable, as described above.

By “partially observed” I mean that, for each GP, I know what the value of the GP should be at a particular X=X_reference (one of the inducing points). Setting these values (one per GP) is necessary to achieve identifiability of the model.

Since my datasets are large, I want to perform variational inference and, ideally, support mini-batch training.

With all of that said, I’d appreciate input on the following points:

  1. At a high level, is there reason to believe that the pyro.contrib.gp module would be preferable to the Pyro/GPyTorch integration for building this model, or vice versa?
  2. To perform the partial observation, would it be sufficient to first sample from the GP and then subtract the difference between the sampled value at X_reference and the desired reference value before using it to predict the actual observable?
  3. Is it still true that pyro.contrib.gp doesn’t support mini-batch inference? Some of the older forum posts indicate that this is the case, but parts of the documentation seem to indicate otherwise.
  4. The code example that I think comes closest to what I want to do is this: https://docs.gpytorch.ai/en/stable/examples/07_Pyro_Integration/Pyro_GPyTorch_Low_Level.html Do you know of any other examples that are even closer to what I’ve described above?
  5. Based on your experience and my description above, what traps or complications might be waiting for me that I haven’t mentioned yet in this post?
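For question 2, the proposed shift can be written as below (a sketch; f_draw, ref_idx, ref_value and the RBF helper are hypothetical). It translates the whole draw by a constant, so the reference point is pinned exactly but every other point moves by the same amount. For comparison, the sketch also includes pathwise conditioning (Matheron's rule), which turns a prior draw into a draw from the GP conditioned on f(X_reference) = ref_value; there the correction decays with distance according to the kernel. The two only coincide for a constant kernel.

```python
import torch


def rbf(x1, x2, lengthscale=1.0):
    # Unit-variance squared-exponential kernel between 1-D inputs
    d = x1.unsqueeze(-1) - x2.unsqueeze(-2)
    return torch.exp(-0.5 * (d / lengthscale) ** 2)


def shift_to_reference(f_draw, ref_idx, ref_value):
    # The proposal in question 2: translate the draw so f[ref_idx] == ref_value
    return f_draw - f_draw[..., ref_idx : ref_idx + 1] + ref_value


def pathwise_condition(f_draw, x, ref_idx, ref_value, lengthscale=1.0):
    # Matheron's rule for one noise-free observation:
    # f_post(x) = f(x) + k(x, x_ref) / k(x_ref, x_ref) * (ref_value - f(x_ref))
    k_ref = rbf(x, x[ref_idx : ref_idx + 1], lengthscale).squeeze(-1)  # k(x, x_ref)
    k_rr = k_ref[ref_idx]  # k(x_ref, x_ref)
    return f_draw + k_ref / k_rr * (ref_value - f_draw[..., ref_idx : ref_idx + 1])
```

Besides the different adjustment, the shifted process has a different prior: its variance is zero at X_reference and grows to roughly twice the kernel variance far away from it.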

As always, I appreciate any and all advice that you can take the time to share.

I don’t have much experience with GPyTorch, but I think it is better suited to your use case because pyro.contrib.gp does not yet support learning independent GPs. What pyro.contrib.gp does support is:

  • multi-task GPs (the task shape can be arbitrary)
  • mini-batch training (see the deep kernel learning example or the deep GP tutorial)

If independent GPs are not required, then I guess you can use contrib.gp:

  • use obs_mask in the sample primitive for partial observation. More flexible solutions can be found in this topic.
  • the latent variable output can be used for other computations, as in the deep GP tutorial

Thanks, @fehiepsi. A multitask GP would actually probably be even better for my purposes, so I’m definitely interested in giving contrib.gp a try. Can you write some pseudocode to demonstrate, roughly, what it would look like to create a multitask GP using contrib.gp and then to use the obs_mask argument to sample from that GP in the context of a pyro model?

For a multitask GP, you can set the latent_shape argument to the desired shape. For example, in the multiclass case of the DKL example, we set latent_shape=(10,) because we need 10 GPs to predict the logit of each class. For obs_mask, I guess you can do something like the following (please put some thought into it; I’m just going on intuition):

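import torch
import pyro
import pyro.distributions as dist

# gp: a contrib.gp model such as VariationalSparseGP; mask: bool tensor, True where y_obs is known
gp.set_data(gp.X, None)  # with y=None, gp.model() returns the latent loc and variance
y_loc, y_var = gp.model()
y_dist = dist.Normal(y_loc, y_var.sqrt())
y_draw = y_dist.rsample()  # stochastic draw - as in the doubly stochastic GP paper
y = torch.where(mask, y_obs, y_draw)  # known values where observed, draws elsewhere
pyro.sample("y_obs", y_dist.mask(mask), obs=y_obs)
# alternatively, we can do
# pyro.sample("y_obs", dist.Normal(y_loc[mask], y_var[mask].sqrt()), obs=y_obs[mask])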

where y_obs is a tensor holding the known values at the entries where mask is True; the remaining entries can be any finite placeholder.

This is great, @fehiepsi; thanks so much. I’m now on to constructing a guide for my hybrid model. What’s the best and/or simplest way to do so? Something like …

def model(...):
  gpmodel.set_data(gpmodel.Xu, None)
  gpmodel_loc, gpmodel_var = gpmodel.model()
  gpmodel_dist = dist.Normal(gpmodel_loc, gpmodel_var.sqrt())
  gpmodel_draw = gpmodel_dist()

  # some additional pyro.param and pyro.sample statements ...
  example_param = pyro.param('example_param', ...)
  example_sample = pyro.sample('example_sample', <use gpmodel_draw somehow>, obs=obs)

def guide(...):
  AutoDelta(model) #?
  gpmodel.guide() #?

For simplicity, I guess you can use AutoDelta on pyro.poutine.block(model, expose=[site_a_name, site_b_name, ...]). But I think explicitly declaring the corresponding sample statements in the guide will help you with debugging.

Thanks again for your help, @fehiepsi. I think I’ve got the guide figured out. I’m still struggling with fixing the value of the GP at one of the inducing points, however. I think this amounts to observing one dimension of a multi-dimensional distribution, as mentioned here (https://github.com/pyro-ppl/pyro/issues/166) by @fritzo. Is obs_mask still the way to go here, or is something more involved required?
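If the goal is to fix one coordinate of the joint Gaussian over the inducing values exactly, the standard Gaussian conditioning formulas give the distribution of the remaining coordinates: mean_o + K_or / K_rr * (ref_value - mean_r) and K_oo - K_or K_ro / K_rr. A sketch, assuming a hand-built RBF prior and the hypothetical names ref_idx and ref_value; it is not taken from contrib.gp or GPyTorch.

```python
import torch


def rbf(x1, x2, lengthscale=1.0):
    # Unit-variance squared-exponential kernel between 1-D inputs
    d = x1.unsqueeze(-1) - x2.unsqueeze(-2)
    return torch.exp(-0.5 * (d / lengthscale) ** 2)


def condition_on_one(mean, cov, ref_idx, ref_value):
    """Mean and covariance of the other coordinates of N(mean, cov) given x[ref_idx] == ref_value."""
    n = mean.shape[-1]
    keep = torch.arange(n) != ref_idx
    k_or = cov[keep, ref_idx]     # Cov(other, ref)
    k_rr = cov[ref_idx, ref_idx]  # Var(ref)
    cond_mean = mean[keep] + k_or / k_rr * (ref_value - mean[ref_idx])
    cond_cov = cov[keep][:, keep] - torch.outer(k_or, k_or) / k_rr
    return cond_mean, cond_cov


Xu = torch.linspace(0.0, 5.0, 8)                # hypothetical inducing inputs
Kuu = rbf(Xu, Xu) + 1e-4 * torch.eye(len(Xu))   # prior covariance with jitter
loc, cov = condition_on_one(torch.zeros(len(Xu)), Kuu, ref_idx=0, ref_value=1.5)
u_other = torch.distributions.MultivariateNormal(loc, covariance_matrix=cov)
```

In a model, one could sample only the non-reference inducing values from u_other and splice ref_value back in at ref_idx, which makes the pinned value exact rather than approximately observed.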

I just realized that my last question doesn’t make sense in the context of this thread: I’m currently working on an implementation of the model with GPyTorch GPs. I’m going to try again with contrib.gp now that I think I better understand what @fehiepsi was suggesting.