I was wondering what the correct way is to use the gp.kernels inside a Pyro model without the GPModel.
As an MWE, I currently have this, adapted from the GP tutorial:
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
from pyro.infer.mcmc.util import initialize_model
from pyro import poutine
# Seed the RNG so the synthetic data set below is reproducible.
pyro.set_rng_seed(0)
# Sample data: N noisy observations of the function 0.5 * sin(3 * x).
# The inputs are drawn before the observation noise; both draws consume the
# seeded random stream, so the statement order determines the data set.
N = 20
X = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))
y = 0.5 * torch.sin(3 * X) + dist.Normal(0.0, 0.2).sample(sample_shape=(N,))
def plot(
    plot_observed_data=False,
    plot_predictions=False,
    n_prior_samples=0,
    model=None,
    kernel=None,
    n_test=500,
):
    """Plot the training data, GP predictions and/or GP prior samples.

    A new 12x6 matplotlib figure is created on every call and the x-axis is
    fixed to [-0.5, 5.5].

    Args:
        plot_observed_data: If true, draw the module-level ``X`` / ``y`` data
            as black crosses.
        plot_predictions: If true, call ``model`` on ``n_test`` evenly spaced
            inputs and draw the predictive mean with a two-sigma band.
            Requires ``model``.
        n_prior_samples: Number of zero-mean GP prior draws to plot. When
            positive, both ``model`` (for the noise term) and ``kernel`` (for
            the covariance) are required.
        model: GP model object; it is called with ``full_cov=True`` and is
            expected to return ``(mean, cov)``.
        kernel: Kernel whose ``forward`` gives the prior covariance.
        n_test: Number of test inputs on [-0.5, 5.5].
    """
    plt.figure(figsize=(12, 6))

    # isinstance rather than an exact ``type(...) ==`` comparison, so that
    # subclasses of VariationalSparseGP follow the same path as the base class
    # (its call takes no ``noiseless`` argument and the noise lives on the
    # likelihood).
    is_sparse_variational = isinstance(model, gp.models.VariationalSparseGP)

    if plot_observed_data:
        plt.plot(X.numpy(), y.numpy(), "kx")

    if plot_predictions:
        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs
        # compute predictive mean and covariance without tracking gradients
        with torch.no_grad():
            if is_sparse_variational:
                mean, cov = model(Xtest, full_cov=True)
            else:
                mean, cov = model(Xtest, full_cov=True, noiseless=False)
        sd = cov.diag().sqrt()  # standard deviation at each input point x
        plt.plot(Xtest.numpy(), mean.numpy(), "r", lw=2)  # plot the mean
        plt.fill_between(
            Xtest.numpy(),  # plot the two-sigma uncertainty about the mean
            (mean - 2.0 * sd).numpy(),
            (mean + 2.0 * sd).numpy(),
            color="C0",
            alpha=0.3,
        )

    if n_prior_samples > 0:  # plot samples from the GP prior
        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs
        noise = model.likelihood.variance if is_sparse_variational else model.noise
        # prior covariance = kernel matrix + noise on the diagonal
        cov = kernel.forward(Xtest) + noise.expand(n_test).diag()
        samples = dist.MultivariateNormal(
            torch.zeros(n_test), covariance_matrix=cov
        ).sample(sample_shape=(n_prior_samples,))
        plt.plot(Xtest.numpy(), samples.numpy().T, lw=2, alpha=0.4)

    plt.xlim(-0.5, 5.5)
def model(jitter=1e-6):
    """Pyro model for zero-mean GP regression on the module-level ``X``, ``y``.

    The covariance is an RBF + Periodic kernel whose lengthscales and period
    get InverseGamma(2, 1) priors; the observation noise standard deviation
    is the ``"std"`` sample site with a HalfNormal(1) prior.

    Args:
        jitter: Small constant added to the covariance diagonal, together
            with the noise variance, before the Cholesky factorisation.

    Returns:
        The kernel object constructed inside this call.
    """
    kernel = gp.kernels.Sum(
        gp.kernels.RBF(input_dim=1),
        gp.kernels.Periodic(input_dim=1),
    )
    # Attach one fresh prior per hyperparameter, in a fixed order:
    # RBF lengthscale, Periodic lengthscale, Periodic period.
    for component, hyperparameter in (
        (kernel.kern0, "lengthscale"),
        (kernel.kern1, "lengthscale"),
        (kernel.kern1, "period"),
    ):
        setattr(
            component,
            hyperparameter,
            pyro.nn.PyroSample(dist.InverseGamma(2.0, 1.0)),
        )

    noise_std = pyro.sample("std", dist.HalfNormal(1))

    # Create covariance matrix
    n_train = X.size(0)
    cov = kernel(X)
    # A stride of n_train + 1 over the flattened square matrix visits exactly
    # its diagonal entries, so this adds jitter + noise variance in place.
    cov.view(-1)[:: n_train + 1] += jitter + noise_std.pow(2)
    chol = torch.linalg.cholesky(cov)

    pyro.sample(
        "y",
        dist.MultivariateNormal(X.new_zeros(n_train), scale_tril=chol),
        obs=y,
    )
    return kernel
def main():
    """Plot the data, run NUTS on ``model`` and rebuild a kernel per draw.

    After sampling, ``model`` is re-executed once per posterior draw with all
    sample sites conditioned on that draw, and the kernels it returns are
    collected. The final assertion checks that the first collected kernel
    carries the first sampled RBF lengthscale.
    """
    num_samples = 50
    num_warmup_steps = 50

    plot(plot_observed_data=True)

    init_params, potential_fn, transforms, _prototype_trace = initialize_model(
        model, num_chains=1, jit_compile=False
    )
    mcmc = MCMC(
        NUTS(potential_fn=potential_fn),
        num_samples=num_samples,
        warmup_steps=num_warmup_steps,
        num_chains=1,
        initial_params=init_params,
        transforms=transforms,
    )
    mcmc.run()
    samples = mcmc.get_samples()

    post_kernel_fns = []
    for ix in range(num_samples):
        # Values of every sample site for posterior draw number ix.
        draw = {site: values[ix] for site, values in samples.items()}
        conditioned_model = poutine.condition(model, data=draw)
        post_kernel_fns.append(conditioned_model())

    # NOTE(review): the kernel's attribute is read here after the conditioned
    # call has already returned, i.e. outside the condition handler.
    # Presumably a PyroSample attribute is then drawn afresh from its prior
    # instead of taking the conditioned value - confirm against the
    # pyro.nn.PyroSample documentation.
    assert post_kernel_fns[0].kern0.lengthscale == samples["kern0.lengthscale"][0]


if __name__ == "__main__":
    main()
The assertion at the end fails (the comparison evaluates to False), so I wondered whether there is a programmatic way to extract the posterior kernel functions.
My end goal is to use those posterior kernel functions to calculate the posterior predictive of the model, so any examples of how to get the posterior predictive from the MCMC samples would be appreciated!
Currently, I am manually setting the hyperparameters of the kernel after inference, and then I am using this function:
def gp_analytic_posterior(
    kernel_fn,
    X,
    new_xs,
    y,
    noise,
    jitter,
    full_cov=False,
):
    """Compute the GP posterior predictive at ``new_xs`` given ``(X, y)``.

    Args:
        kernel_fn: Kernel with fixed hyperparameters; called as
            ``kernel_fn(X)`` and also passed to ``gp.util.conditional``.
        X: Training inputs; ``X.size(0)`` is the number of training points.
        new_xs: Test inputs at which to evaluate the posterior.
        y: Training targets.
        noise: Observation noise standard deviation (it is squared before
            use). May be a tensor or a plain Python number.
        jitter: Small constant added to the training covariance diagonal and
            forwarded to ``gp.util.conditional``.
        full_cov: If true, return a full covariance; otherwise return the
            per-point variance.

    Returns:
        Tuple ``(mean, cov)`` with the noise variance included in ``cov``.
    """
    n_train = X.size(0)
    # ``**`` works for tensors and plain numbers alike, whereas
    # ``torch.pow(noise, 2)`` rejects a bare Python float. Computed once.
    noise_var = noise ** 2

    # Clone so the in-place diagonal update never touches the kernel's output.
    Kff = kernel_fn(X).contiguous()
    Kff = Kff.type(X.dtype).clone()
    # A stride of n_train + 1 over the flattened matrix addresses its diagonal.
    Kff.view(-1)[:: n_train + 1] += jitter + noise_var
    Lff = torch.linalg.cholesky(Kff)

    gp_post_mean, gp_post_cov = gp.util.conditional(
        new_xs, X, kernel_fn, y, Lff=Lff, jitter=jitter, full_cov=full_cov
    )

    if full_cov:
        n_test = new_xs.size(0)
        gp_post_cov = gp_post_cov.contiguous()
        # Flatten each (n_test, n_test) matrix and add noise on its diagonal.
        gp_post_cov.view(-1, n_test * n_test)[:, :: n_test + 1] += noise_var
    else:
        gp_post_cov = gp_post_cov + noise_var
    return gp_post_mean, gp_post_cov
which is based on the code in the GPRegression module.
Thanks for the help!