Hi,
I’ve been trying to implement something based on this tutorial that integrates Pyro and GPyTorch. I’m working with RNA sequencing data and fitting an individual model for each gene (I’m sure there’s a more efficient way to do this, but I’m just trying something out at the moment). When I iterate over multiple genes, the first model trains correctly, but every subsequent model completely fails to learn and stays stuck at its initial loss.
My first thought was that some piece of information was being carried over between iterations when it shouldn’t be, but deleting the model and emptying the CUDA cache at the end of each loop changed nothing. I’ve also tried different optimizers and different hyperparameters, but that hasn’t solved it either. I feel like I’m missing something obvious, but I haven’t been able to figure it out, so hopefully someone else can spot what I’m doing wrong. I’ve included code that reproduces the problem below.
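For what it’s worth, one way I can think of to test the held-over-state theory would be to dump Pyro’s global parameter store between genes and see whether names from earlier models are still present. This is just a diagnostic sketch, not something from the tutorial:

```python
import pyro

# Diagnostic sketch: if parameter names created for earlier genes are
# still in Pyro's global parameter store when a later gene trains,
# state is leaking across the per-gene models.
for name in pyro.get_param_store().keys():
    print(name)
```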
```python
import numpy as np
import torch
import gpytorch
import pyro
from tqdm import tqdm

train_x = torch.from_numpy(np.load('train_x.npy')).float()
train_y = torch.from_numpy(np.load('train_y.npy')).float()

class PVGPRegressionModel(gpytorch.models.ApproximateGP):
    def __init__(self, num_inducing=128, name_prefix="gene_model"):
        self.name_prefix = name_prefix
        # xmin/xmax are globals set inside the per-gene loop below
        inducing_points = torch.linspace(xmin, xmax, num_inducing)
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points,
            gpytorch.variational.CholeskyVariationalDistribution(num_inducing_points=num_inducing)
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

    def guide(self, x, y):
        function_dist = self.pyro_guide(x)
        with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
            pyro.sample(self.name_prefix + ".f(x)", function_dist)

    def model(self, x, y):
        pyro.module(self.name_prefix + ".gp", self)
        function_dist = self.pyro_model(x)
        with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
            function_samples = pyro.sample(self.name_prefix + ".f(x)", function_dist)
            # exponentiate the latent GP so the Poisson rate is positive
            scale_samples = function_samples.exp()
            return pyro.sample(
                self.name_prefix + ".y",
                pyro.distributions.Poisson(scale_samples),
                obs=y
            )

num_iter = 200
num_particles = 256

for gene in range(2, 5):
    path = 2
    train_x_path = train_x[:, path]
    train_y_gene = train_y[:, gene]
    xmin = min(train_x_path.min().floor(), 0)
    xmax = train_x_path.max().ceil()

    model = PVGPRegressionModel()
    optimizer = pyro.optim.Adam({"lr": 0.1})
    elbo = pyro.infer.Trace_ELBO(num_particles=num_particles, vectorize_particles=True, retain_graph=True)
    svi = pyro.infer.SVI(model.model, model.guide, optimizer, elbo)

    model.train()
    iterator = tqdm(range(num_iter))
    losses = np.zeros(num_iter)
    for i in iterator:
        model.zero_grad()
        loss = svi.step(train_x_path, train_y_gene)
        losses[i] = loss
        iterator.set_postfix(loss=loss, lengthscale=model.covar_module.base_kernel.lengthscale.item())
    print(losses)

    # cleanup between genes (this didn't change anything)
    del model, losses
    torch.cuda.empty_cache()
```
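In case it helps anyone reproduce this without my data files: the arrays are just 2-D float arrays indexed as `train_x[:, path]` and `train_y[:, gene]`. A synthetic stand-in (made-up shapes and values, not my actual data) would be:

```python
import numpy as np

# Hypothetical stand-in for the real files: 100 cells, 4 input columns
# ("paths") and 8 gene-count columns. Only the shapes matter for the
# loop above; the values here are random.
rng = np.random.default_rng(0)
np.save('train_x.npy', rng.uniform(0.0, 10.0, size=(100, 4)).astype(np.float32))
np.save('train_y.npy', rng.poisson(lam=3.0, size=(100, 8)).astype(np.float32))
```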