Hi,
I am using Pyro for a project that samples hyperparameter posteriors with GPyTorch, a GP package built on top of PyTorch. However, when I was testing a simple example (below), the potential function built into Pyro does not seem to return consistent function values/gradients. Code to replicate:
import math

import numpy as np
import torch
from matplotlib import pyplot as plt

import gpytorch
from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior
import pyro
from pyro.infer.mcmc import NUTS, MCMC, HMC
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regression model: constant mean + scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel()
        )

    def forward(self, x):
        """Return the latent-function MultivariateNormal at inputs ``x``."""
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)
# Toy 1-D regression problem: noisy sine data, kept in double precision.
# NOTE: this uses `math.pi`, which requires `import math` at the top of the
# file (missing in the original paste and would raise a NameError).
train_x = torch.linspace(0, 1, 10).double()
train_y = torch.sin(train_x * (2 * math.pi)).double() + torch.randn(train_x.size()).double() * 0.1

likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())
model = ExactGPModel(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

# Register uniform priors on the (constrained) hyperparameters; the prior
# names double as the Pyro sample-site names.
model.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
model.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.0, 9.0), "lengthscale")
model.covar_module.register_prior("outputscale_prior", UniformPrior(0, 4), "outputscale")
likelihood.register_prior("noise_prior", UniformPrior(0.0, 0.25), "noise")

# Run everything in float64.
model.double()
likelihood.double()
# define pyro primitive
def pyro_model(x, y):
    """Pyro model: draw GP hyperparameters from their priors, then score
    the observed data via the exact marginal log likelihood.

    NOTE(review): `pyro_sample_from_prior()` draws fresh hyperparameters on
    every invocation — presumably intended, but worth confirming against the
    inconsistent potential values observed.
    """
    model.pyro_sample_from_prior()
    prediction = model(x)
    with gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False):
        scaled_mll = mll(prediction, y) * y.shape[0]
    pyro.factor("gp_mll", scaled_mll)
# Initialize the hyperparameters at fixed starting values.
model.mean_module.constant.data.fill_(0.0)
model.covar_module.outputscale = 0.5 ** 2
model.covar_module.base_kernel.lengthscale = 1
model.likelihood.noise = 0.05 ** 2

# Starting point for the sampler, keyed by the Pyro sample-site names and
# expressed in the unconstrained ("raw") parameter space.
initial_params = {
    "mean_module.mean_prior": model.mean_module.constant.detach(),
    "covar_module.base_kernel.lengthscale_prior": model.covar_module.base_kernel.raw_lengthscale.detach(),
    "covar_module.outputscale_prior": model.covar_module.raw_outputscale.detach(),
    "likelihood.noise_prior": model.likelihood.raw_noise.detach(),
}

# Set up NUTS, then evaluate the potential at the *same* point three times;
# the printed values should agree if the potential is deterministic.
args = (train_x, train_y)
kwargs = {}
nuts_kernel = NUTS(pyro_model)
nuts_kernel.setup(0, *args, **kwargs)
for _ in range(3):
    grads, v = pyro.ops.integrator.potential_grad(nuts_kernel.potential_fn, initial_params)
    print(grads["mean_module.mean_prior"])
    print(v)
I got something like:
tensor([-0.0004], dtype=torch.float64)
tensor(6153.4552, dtype=torch.float64)
tensor([-0.0002], dtype=torch.float64)
tensor(24555.3200, dtype=torch.float64)
tensor([-0.0001], dtype=torch.float64)
tensor(98107.9468, dtype=torch.float64)
I would greatly appreciate it if anyone knows what is happening and how I could resolve this issue.