Hey guys,
I've been working on integrating Pyro and GPyTorch to build a latent Gaussian process model (with the hope of reusing it in other applications). I've read the recent papers on sparse latent Gaussian processes and believe this code should work, but I'm missing something. I have already written deep latent Gaussian process code from scratch in Pyro (and it works), so I understand the approach reasonably well, but I want to use GPyTorch for better scalability.
As far as I can tell this code should work, yet it doesn't. If someone can spot what the issue is, I'd be pretty impressed. Any help would be much appreciated. The dataset is the first 2000 samples of MNIST:
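# Full script: load MNIST, define a single-layer latent deep GP, and train it with Pyro SVI.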
import os
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import torch
from torch.distributions.kl import kl_divergence

import pyro
import gpytorch
from gpytorch.models.gplvm.latent_variable import *
from gpytorch.models.gplvm.bayesian_gplvm import BayesianGPLVM
from gpytorch.means import ZeroMean, ConstantMean, LinearMean
from gpytorch.mlls import VariationalELBO, AddedLossTerm, DeepApproximateMLL
from gpytorch.priors import NormalPrior
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.models import ApproximateGP, GP
from gpytorch.models.deep_gps import DeepGPLayer, DeepGP
from tqdm.notebook import trange
from sklearn.manifold import TSNE
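# Corrupt images by zeroing a random square patch in each one; returns the
# corrupted images together with a binary mask marking which pixels remain observed.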
def random_zero_grid(tensor, grid_size=9):
    N, H, W = tensor.shape
    tensor = torch.tensor(tensor)
    mask = torch.ones_like(tensor, dtype=torch.int64)
    for n in range(N):
        # Randomly choose the top-left corner of the grid_size x grid_size patch
        top_left_x = np.random.randint(0, H - grid_size + 1)
        top_left_y = np.random.randint(0, W - grid_size + 1)
        # Zero out the patch in the image
        tensor[n, top_left_x:top_left_x + grid_size, top_left_y:top_left_y + grid_size] = 0
        # Mark the same pixels as missing in the mask
        mask[n, top_left_x:top_left_x + grid_size, top_left_y:top_left_y + grid_size] = 0
    return tensor, mask
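# Load MNIST, keep the first N images for training and the next N_missing images
# (each with a random patch removed) as a held-out set with missing pixels.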
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
N = 2000
small_x_train = x_train[:N, ...].astype(np.float64) / 256.
small_y_train = y_train[:N]
N_missing = 500
x_missing_train_init = x_train[N:N+N_missing, ...].astype(np.float64) / 256.
x_missing_train, x_missing_mask = random_zero_grid(x_missing_train_init)
y_missing_train = y_train[N:N+N_missing]
observations_ = torch.tensor(small_x_train.reshape(N, -1).transpose())
obs_missing = x_missing_train.reshape(N_missing, -1).transpose(-1,-2)
mask_missing = x_missing_mask.reshape(N_missing, -1).transpose(-2,-1)
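# A single sparse variational GP layer (learned inducing points with a Cholesky
# variational distribution), used here as the output layer of the deep GP.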
class ToyDeepGPHiddenLayer(DeepGPLayer):
    def __init__(self, input_dims, output_dims, num_inducing=30, mean_type='constant'):
        if output_dims is None:
            inducing_points = torch.randn(num_inducing, input_dims)
            batch_shape = torch.Size([])
        else:
            inducing_points = torch.randn(output_dims, num_inducing, input_dims)
            batch_shape = torch.Size([output_dims])
        variational_distribution = CholeskyVariationalDistribution(
            num_inducing_points=num_inducing,
            batch_shape=batch_shape
        )
        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True
        )
        super(ToyDeepGPHiddenLayer, self).__init__(variational_strategy, input_dims, output_dims)
        self.mean_module = ZeroMean()
        self.covar_module = ScaleKernel(RBFKernel(ard_num_dims=input_dims))
        #if mean_type == 'constant':
        #    self.mean_module = ConstantMean(batch_shape=batch_shape)
        #else:
        #    self.mean_module = LinearMean(input_dims)
        #self.covar_module = ScaleKernel(
        #    RBFKernel(batch_shape=batch_shape, ard_num_dims=input_dims),
        #    batch_shape=batch_shape, ard_num_dims=None
        #)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, *other_inputs, **kwargs):
        """
        Overriding __call__ isn't strictly necessary, but it lets us add concatenation-based skip
        connections easily. For example, hidden_layer2(hidden_layer1_outputs, inputs) will pass the
        concatenation of the first hidden layer's outputs and the input data to hidden_layer2.
        """
        if len(other_inputs):
            if isinstance(x, gpytorch.distributions.MultitaskMultivariateNormal):
                x = x.rsample()
            processed_inputs = [
                inp.unsqueeze(0).expand(gpytorch.settings.num_likelihood_samples.value(), *inp.shape)
                for inp in other_inputs
            ]
            x = torch.cat([x] + processed_inputs, dim=-1)
        #return super().__call__(x, are_samples=bool(len(other_inputs)))
        return super().__call__(x, are_samples=True)
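# Latent dimensionality of Z and output dimensionality (flattened 28x28 images).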
num_hidden_dims = 10
num_output_dims = 28**2
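# Latent deep GP: a standard normal prior over the latent inputs Z, a variational Normal
# posterior over Z (its KL term added by hand via pyro.factor), and a GPyTorch deep GP
# layer mapping Z to the 784 observed pixel dimensions.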
class Latent_DeepGP(DeepGP):
    def __init__(self, hidden_length, name_prefix="DeepGP"):
        #hidden_layer = ToyDeepGPHiddenLayer(
        #    input_dims=hidden_length,
        #    output_dims=num_hidden_dims,
        #    mean_type='constant',
        #)
        #middle_layer = ToyDeepGPHiddenLayer(
        #    input_dims=hidden_layer.output_dims,
        #    output_dims=num_hidden_dims,
        #    mean_type='constant',
        #)
        #last_layer = ToyDeepGPHiddenLayer(
        #    input_dims=hidden_layer.output_dims,
        #    output_dims=num_output_dims,
        #    mean_type='constant',
        #)
        last_layer = ToyDeepGPHiddenLayer(
            input_dims=hidden_length,
            output_dims=num_output_dims,
            mean_type='constant',
        )
        super().__init__()
        #self.hidden_layer = hidden_layer
        #self.middle_layer = middle_layer
        self.last_layer = last_layer
        self.layer_list = [
            #self.hidden_layer,
            #self.middle_layer,
            self.last_layer,
        ]
        #self.likelihood = GaussianLikelihood()
        self.name_prefix = name_prefix
    def forward(self, inputs):
        # Only the single output layer is active; the hidden/middle layers above are commented out.
        output = self.last_layer(inputs)
        return output
    def guide(self, y):
        # q(Z) - variational (guide) distribution over the latent inputs
        latent_loc = pyro.param('Z_loc', torch.randn(N, num_hidden_dims))
        latent_scale = pyro.param('Z_scale', torch.randn(N, num_hidden_dims))
        latent_vals = pyro.sample(
            'Z_val',
            pyro.distributions.Normal(loc=latent_loc, scale=torch.exp(latent_scale)).to_event(2).mask(False)
        )
        output = latent_vals
        for i, layer in enumerate(self.layer_list):
            function_dist = layer.pyro_guide(output, name_prefix=self.name_prefix + f'_{i}')
            with pyro.plate(self.name_prefix + f".data_plate_{i}", dim=-1):
                # Sample from latent function distribution
                output = pyro.sample(self.name_prefix + f".f(x)_{i}", function_dist)
    def model(self, y):
        pyro.module(self.name_prefix + ".gp", self)
        # p(Z) - standard normal prior over the latent inputs
        prior = pyro.distributions.Normal(loc=torch.zeros(N, num_hidden_dims), scale=torch.ones(N, num_hidden_dims)).to_event(2)
        latent_vals = pyro.sample('Z_val', prior.mask(False))
        latent_loc = pyro.param('Z_loc', torch.randn(N, num_hidden_dims))
        latent_scale = pyro.param('Z_scale', torch.randn(N, num_hidden_dims))
        posterior = pyro.distributions.Normal(loc=latent_loc, scale=torch.exp(latent_scale)).to_event(2)
        # Add the KL term for Z by hand, since both Z_val sample sites are masked out
        pyro.factor("Z_kl_div", -kl_divergence(posterior, prior))
        # Use a plate here to mark conditional independencies
        output = latent_vals
        for i, layer in enumerate(self.layer_list):
            function_dist = layer.pyro_model(output, name_prefix=self.name_prefix + f'_{i}')
            with pyro.plate(self.name_prefix + f".data_plate_{i}", dim=-1):
                # Sample from latent function distribution
                output = pyro.sample(self.name_prefix + f".f(x)_{i}", function_dist)
        output_obs_error = pyro.param('output_obs_error', torch.zeros(num_output_dims))
        # Sample from the observed (Gaussian) likelihood
        y_vals = pyro.sample(
            self.name_prefix + ".y",
            pyro.distributions.Normal(loc=output, scale=torch.exp(output_obs_error)).to_event(2),
            obs=y
        )
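# Project the learned latent means to 2D with t-SNE and color points by digit label.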
def plot_results(embed_before, y_train, N):
    tsne = TSNE(n_components=2, random_state=42)
    embed_before_2d = tsne.fit_transform(embed_before)
    # Plot the latent locations after training
    plt.figure(figsize=(7, 7))
    plt.title("After training")
    plt.grid(False)
    plt.scatter(x=embed_before_2d[:, 0], y=embed_before_2d[:, 1],
                c=y_train[:N], cmap=plt.get_cmap('Paired'), s=50)
    plt.show()
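# Train with Pyro SVI (ClippedAdam + Trace_ELBO), then visualise the learned latents.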
def main1():
    data = torch.tensor(small_x_train.reshape(N, -1))
    model = Latent_DeepGP(num_hidden_dims)
    optimizer = pyro.optim.ClippedAdam({'lr': 5e-2, 'clip_norm': 10.0})
    elbo = pyro.infer.Trace_ELBO(retain_graph=True)
    svi = pyro.infer.SVI(model.model, model.guide, optimizer, elbo)
    model.train()
    log_interval = 10
    iterator = range(10000)
    for i in iterator:
        model.zero_grad()
        loss = svi.step(data)
        if i % log_interval == 0:
            print(f'Iteration {i}, loss: {loss}')
    lat_embed = pyro.param('Z_loc').detach().numpy()
    plot_results(lat_embed,
                 small_y_train,
                 N
                 )
    print()
if __name__ == '__main__':
    main1()