GPyTorch and Pyro Integration - Latent Gaussian Process Model

Hey guys,

I have been trying to integrate Pyro and GPyTorch to build a working latent Gaussian process model (with the hope of using it in other applications). I have read recent papers on sparse latent Gaussian processes and believe this code should work, but I’m missing something. I have already written a deep latent Gaussian process model from scratch in Pyro (which works), so I understand the approach reasonably well, but I want to bring in GPyTorch for better scalability.

As far as I can tell the code below should work, but for some reason it does not. If someone can spot what the issue is, I’d be impressed, and any help would be much appreciated. The dataset is the first 2000 samples from MNIST:

import os
from pathlib import Path

import numpy as np
import torch
import tensorflow as tf  # only used to load MNIST
from matplotlib import pyplot as plt
from tqdm.notebook import trange
from sklearn.manifold import TSNE

import pyro
from torch.distributions.kl import kl_divergence

import gpytorch
from gpytorch.models.gplvm.latent_variable import *
from gpytorch.models.gplvm.bayesian_gplvm import BayesianGPLVM
from gpytorch.means import ZeroMean, ConstantMean, LinearMean
from gpytorch.mlls import VariationalELBO, AddedLossTerm, DeepApproximateMLL
from gpytorch.priors import NormalPrior
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.models import ApproximateGP, GP
from gpytorch.models.deep_gps import DeepGPLayer, DeepGP

def random_zero_grid(tensor, grid_size=9):
    """Zero out a random grid_size x grid_size patch in each image; return the masked images and the mask."""
    N, H, W = tensor.shape
    tensor = torch.tensor(tensor)
    mask = torch.ones_like(tensor, dtype=torch.int64)

    for n in range(N):
        # Randomly choose the top-left corner of the patch
        top_left_x = np.random.randint(0, H - grid_size + 1)
        top_left_y = np.random.randint(0, W - grid_size + 1)

        # Zero out the patch in the image
        tensor[n, top_left_x:top_left_x + grid_size, top_left_y:top_left_y + grid_size] = 0

        # Record the removed patch in the mask
        mask[n, top_left_x:top_left_x + grid_size, top_left_y:top_left_y + grid_size] = 0

    return tensor, mask

(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
N = 2000
small_x_train = x_train[:N, ...].astype(np.float64) / 256.
small_y_train = y_train[:N]

N_missing = 500
x_missing_train_init = x_train[N:N+N_missing, ...].astype(np.float64) / 256.
x_missing_train, x_missing_mask = random_zero_grid(x_missing_train_init)
y_missing_train = y_train[N:N+N_missing]


observations_ = torch.tensor(small_x_train.reshape(N, -1).transpose())

obs_missing = x_missing_train.reshape(N_missing, -1).transpose(-1,-2)
mask_missing = x_missing_mask.reshape(N_missing, -1).transpose(-2,-1)

class ToyDeepGPHiddenLayer(DeepGPLayer):
    def __init__(self, input_dims, output_dims, num_inducing=30, mean_type='constant'):
        if output_dims is None:
            inducing_points = torch.randn(num_inducing, input_dims)
            batch_shape = torch.Size([])
        else:
            inducing_points = torch.randn(output_dims, num_inducing, input_dims)
            batch_shape = torch.Size([output_dims])

        variational_distribution = CholeskyVariationalDistribution(
            num_inducing_points=num_inducing,
            batch_shape=batch_shape
        )

        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True
        )

        super(ToyDeepGPHiddenLayer, self).__init__(variational_strategy, input_dims, output_dims)

        self.mean_module = ZeroMean()
        self.covar_module = ScaleKernel(RBFKernel(ard_num_dims=input_dims))

        #if mean_type == 'constant':
        #    self.mean_module = ConstantMean(batch_shape=batch_shape)
        #else:
        #    self.mean_module = LinearMean(input_dims)
        #self.covar_module = ScaleKernel(
        #    RBFKernel(batch_shape=batch_shape, ard_num_dims=input_dims),
        #    batch_shape=batch_shape, ard_num_dims=None
        #)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, *other_inputs, **kwargs):
        """
        Overriding __call__ isn't strictly necessary, but it lets us add concatenation based skip connections
        easily. For example, hidden_layer2(hidden_layer1_outputs, inputs) will pass the concatenation of the first
        hidden layer's outputs and the input data to hidden_layer2.
        """
        if len(other_inputs):
            if isinstance(x, gpytorch.distributions.MultitaskMultivariateNormal):
                x = x.rsample()

            processed_inputs = [
                inp.unsqueeze(0).expand(gpytorch.settings.num_likelihood_samples.value(), *inp.shape)
                for inp in other_inputs
            ]

            x = torch.cat([x] + processed_inputs, dim=-1)

        #return super().__call__(x, are_samples=bool(len(other_inputs)))
        return super().__call__(x, are_samples=True)

num_hidden_dims = 10
num_output_dims = 28**2

class Latent_DeepGP(DeepGP):
    def __init__(self, hidden_length, name_prefix="DeepGP"):

        #hidden_layer = ToyDeepGPHiddenLayer(
        #    input_dims=hidden_length,
        #    output_dims=num_hidden_dims,
        #    mean_type='constant',
        #)

        #middle_layer = ToyDeepGPHiddenLayer(
        #    input_dims=hidden_layer.output_dims,
        #    output_dims=num_hidden_dims,
        #    mean_type='constant',
        #)

        #last_layer = ToyDeepGPHiddenLayer(
        #    input_dims=hidden_layer.output_dims,
        #    output_dims=num_output_dims,
        #    mean_type='constant',
        #)

        last_layer = ToyDeepGPHiddenLayer(
            input_dims=hidden_length,
            output_dims=num_output_dims,
            mean_type='constant',
        )

        super().__init__()

        #self.hidden_layer = hidden_layer
        #self.middle_layer = middle_layer
        self.last_layer = last_layer

        self.layer_list = [
            #self.hidden_layer,
            #self.middle_layer,
            self.last_layer,
            ]
        #self.likelihood = GaussianLikelihood()
        self.name_prefix = name_prefix

    def forward(self, inputs):
        # Pass the latent inputs through whichever layers are active in layer_list
        output = inputs
        for layer in self.layer_list:
            output = layer(output)
        return output

    def guide(self, y):
        # q(Z) - variational (guide) distribution over the latent inputs
        latent_loc = pyro.param('Z_loc', torch.randn(N, num_hidden_dims))
        latent_scale = pyro.param('Z_scale', torch.randn(N, num_hidden_dims))
        latent_vals = pyro.sample(
            'Z_val',
            pyro.distributions.Normal(loc=latent_loc, scale=torch.exp(latent_scale)).to_event(2).mask(False)
        )

        output = latent_vals
        for i, layer in enumerate(self.layer_list):
            function_dist = layer.pyro_guide(output, name_prefix=self.name_prefix + f'_{i}')

            with pyro.plate(self.name_prefix + f".data_plate_{i}", dim=-1):
                # Sample from the latent function distribution
                output = pyro.sample(self.name_prefix + f".f(x)_{i}", function_dist)

    def model(self, y):
        pyro.module(self.name_prefix + ".gp", self)

        prior = pyro.distributions.Normal(
            loc=torch.zeros(N, num_hidden_dims), scale=torch.ones(N, num_hidden_dims)
        ).to_event(2)
        latent_vals = pyro.sample('Z_val', prior.mask(False))

        latent_loc = pyro.param('Z_loc', torch.randn(N, num_hidden_dims))
        latent_scale = pyro.param('Z_scale', torch.randn(N, num_hidden_dims))
        posterior = pyro.distributions.Normal(loc=latent_loc, scale=torch.exp(latent_scale)).to_event(2)

        # Add KL(q(Z) || p(Z)) by hand, since the Z_val sample sites are masked out above
        pyro.factor("Z_kl_div", -kl_divergence(posterior, prior))
        # Use a plate here to mark conditional independencies
        output = latent_vals
        for i, layer in enumerate(self.layer_list):

            function_dist = layer.pyro_model(output, name_prefix=self.name_prefix + f'_{i}')

            with pyro.plate(self.name_prefix + f".data_plate_{i}", dim=-1):
                # Sample from latent function distribution
                output = pyro.sample(self.name_prefix + f".f(x)_{i}", function_dist)

        
        output_obs_error = pyro.param('output_obs_error', torch.zeros(num_output_dims))
        # Sample the observed pixels given the GP output
        y_vals = pyro.sample(
            self.name_prefix + ".y",
            pyro.distributions.Normal(loc=output, scale=torch.exp(output_obs_error)).to_event(2),
            obs=y
        )


def plot_results(embed, y_train, N):
    # Project the learned latent locations to 2D with t-SNE and colour by digit label
    tsne = TSNE(n_components=2, random_state=42)
    embed_2d = tsne.fit_transform(embed)

    plt.figure(figsize=(7, 7))
    plt.title("After training")
    plt.grid(False)
    plt.scatter(x=embed_2d[:, 0], y=embed_2d[:, 1],
                c=y_train[:N], cmap=plt.get_cmap('Paired'), s=50)
    plt.show()



def main1():

    data = torch.tensor(small_x_train.reshape(N, -1))
    model = Latent_DeepGP(num_hidden_dims)
    optimizer = pyro.optim.ClippedAdam({'lr': 5e-2,'clip_norm':10.0})
    elbo = pyro.infer.Trace_ELBO(retain_graph=True)
    svi = pyro.infer.SVI(model.model, model.guide, optimizer, elbo)

    model.train()
    log_interval = 10
    iterator = range(10000)
    for i in iterator:

        model.zero_grad()
        loss = svi.step(data)

        if i % log_interval == 0:
            print(f'Iteration {i}, loss: {loss}')


    lat_embed = pyro.param('Z_loc').detach().numpy()
    plot_results(lat_embed, 
                small_y_train,
                N
            )

    print()


if __name__ == '__main__':

    main1()


have you seen this? Pyro Integration — GPyTorch 0.1.dev97+gf73fa7d documentation

no idea if your code is correct but it’s almost always a bad idea to initialize variational distributions to be broad (and thus lead to high variance elbos)

@martinjankowiak Thanks for the reply! Yes, I have seen that, and it is in fact part of what I’m using in my code. I have created another model that uses only Pyro along with the initialization:

latent_scale = pyro.param('Z_scale',torch.randn(N,num_hidden_dims))

and it works fine, so I don’t believe that is the issue, although that is really good to know for the future. To improve the initialization, though, is the idea that I would correct it to something like this?

latent_scale = pyro.param('Z_scale',torch.zeros(N,num_hidden_dims)) # setting to 0 since I apply exp later

However, you’ve raised a point about initialization that could be the real problem. Even a prior of the form

pyro.distributions.Normal(loc=torch.zeros(N,num_hidden_dims),scale=torch.ones(N,num_hidden_dims)).to_event(2)

has more total variance the more latent dimensions you have, and that could cause problems with the kernel, since it operates on pairs of num_hidden_dims-dimensional vectors (squared distances, for the RBF kernel used here), so the effect gets worse as num_hidden_dims increases. I made adjustments for this in my own from-scratch code, but I wonder whether it is causing an issue in the GPyTorch version. I could be way off, but I might test this out.
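
To make that concrete, here is a rough sketch of the effect I have in mind (plain PyTorch with a hand-written RBF kernel and a fixed lengthscale of 1.0, purely for illustration; none of these numbers come from the model above):

import torch

torch.manual_seed(0)

lengthscale = 1.0  # fixed kernel lengthscale, just for illustration

# Under a standard normal prior the expected squared distance between two latent
# points is about 2 * num_hidden_dims, so exp(-d^2 / (2 * lengthscale^2)) collapses
# towards 0 as the latent dimension grows unless the lengthscale grows with it.
for num_hidden_dims in [2, 10, 50, 200]:
    z = torch.randn(1000, num_hidden_dims)
    sq_dist = torch.cdist(z, z).pow(2)
    k = torch.exp(-sq_dist / (2 * lengthscale ** 2))
    off_diag = k[~torch.eye(1000, dtype=torch.bool)]
    print(f"dims={num_hidden_dims:4d}  "
          f"mean squared distance={sq_dist.mean():7.1f}  "
          f"mean off-diagonal kernel value={off_diag.mean():.2e}")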

Again, thanks for the reply. If any other ideas come to mind on what the issue might be, I’d be curious to hear.

variances should be small. e.g.
pyro.param('Z_scale', (0.01 * torch.randn(N,num_hidden_dims) - 3).exp())

@martinjankowiak You are completely correct. Your suggestion fixed the issue. You’re the man! Thanks a ton. I would have never seen that.

There are still issues when adding multiple GP layers, but this suggestion fixed the code so that it at least works with one layer.
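
For anyone who hits the same thing later, the change amounts to something like the following (a sketch of the adapted form rather than my exact code: since the guide and model apply torch.exp to Z_scale, I initialize the log-scale itself to be small, so the initial standard deviation comes out around exp(-3) ≈ 0.05):

import torch
import pyro

N, num_hidden_dims = 2000, 10  # same values as in the code above

# Assumed adaptation of the suggestion: keep Z_scale as a log-scale parameter, so
# torch.exp(latent_scale) in the guide/model gives a small initial standard deviation.
latent_loc = pyro.param('Z_loc', torch.randn(N, num_hidden_dims))
latent_scale = pyro.param('Z_scale', 0.01 * torch.randn(N, num_hidden_dims) - 3)
latent_vals = pyro.sample(
    'Z_val',
    pyro.distributions.Normal(loc=latent_loc, scale=torch.exp(latent_scale)).to_event(2).mask(False)
)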

So it turns out Martin’s suggestion did fix the core problem. The remaining issue, however, is that deeper models built with GPyTorch train extremely slowly (my own from-scratch model trains much faster). I’m not sure what in their code makes training slow on deeper models, but it very clearly is the case.