Using gp.kernels inside a Pyro model

I was wondering about what the correct way is to use the gp.kernels inside a Pyro model without the GPModel. As a MWE I currently have this adapted from the GP tutorial:

import matplotlib.pyplot as plt
import torch

import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
from pyro.infer.mcmc.util import initialize_model
from pyro import poutine

pyro.set_rng_seed(0)

# Sample data
N = 20
X = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))
y = 0.5 * torch.sin(3 * X) + dist.Normal(0.0, 0.2).sample(sample_shape=(N,))


def plot(
    plot_observed_data=False,
    plot_predictions=False,
    n_prior_samples=0,
    model=None,
    kernel=None,
    n_test=500,
):

    plt.figure(figsize=(12, 6))
    if plot_observed_data:
        plt.plot(X.numpy(), y.numpy(), "kx")
    if plot_predictions:
        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs
        # compute predictive mean and variance
        with torch.no_grad():
            if type(model) == gp.models.VariationalSparseGP:
                mean, cov = model(Xtest, full_cov=True)
            else:
                mean, cov = model(Xtest, full_cov=True, noiseless=False)
        sd = cov.diag().sqrt()  # standard deviation at each input point x
        plt.plot(Xtest.numpy(), mean.numpy(), "r", lw=2)  # plot the mean
        plt.fill_between(
            Xtest.numpy(),  # plot the two-sigma uncertainty about the mean
            (mean - 2.0 * sd).numpy(),
            (mean + 2.0 * sd).numpy(),
            color="C0",
            alpha=0.3,
        )
    if n_prior_samples > 0:  # plot samples from the GP prior
        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs
        noise = (
            model.noise
            if type(model) != gp.models.VariationalSparseGP
            else model.likelihood.variance
        )
        cov = kernel.forward(Xtest) + noise.expand(n_test).diag()
        samples = dist.MultivariateNormal(
            torch.zeros(n_test), covariance_matrix=cov
        ).sample(sample_shape=(n_prior_samples,))
        plt.plot(Xtest.numpy(), samples.numpy().T, lw=2, alpha=0.4)

    plt.xlim(-0.5, 5.5)


def model(jitter=1e-6):
    kernel = gp.kernels.Sum(
        gp.kernels.RBF(input_dim=1), gp.kernels.Periodic(input_dim=1)
    )
    kernel.kern0.lengthscale = pyro.nn.PyroSample(dist.InverseGamma(2.0, 1.0))
    kernel.kern1.lengthscale = pyro.nn.PyroSample(dist.InverseGamma(2.0, 1.0))
    kernel.kern1.period = pyro.nn.PyroSample(dist.InverseGamma(2.0, 1.0))

    std = pyro.sample("std", dist.HalfNormal(1))

    # Create covariance matrix
    N = X.size(0)
    Kff = kernel(X)
    Kff.view(-1)[:: N + 1] += jitter + torch.pow(std, 2)  # add noise to diagonal
    Lff = torch.linalg.cholesky(Kff)

    zero_loc = X.new_zeros(N)
    pyro.sample("y", dist.MultivariateNormal(zero_loc, scale_tril=Lff), obs=y)
    return kernel


def main():
    num_samples = 50
    num_warmup_steps = 50
    plot(plot_observed_data=True)

    init_params, potential_fn, transforms, _ = initialize_model(
        model, num_chains=1, jit_compile=False
    )
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=num_samples,
        warmup_steps=num_warmup_steps,
        num_chains=1,
        initial_params=init_params,
        transforms=transforms,
    )
    mcmc.run()

    samples = mcmc.get_samples()
    post_kernel_fns = [
        poutine.condition(model, data={k: samples[k][ix] for k in samples.keys()})()
        for ix in range(num_samples)
    ]
    assert post_kernel_fns[0].kern0.lengthscale == samples["kern0.lengthscale"][0]


if __name__ == "__main__":
    main()

The assertion at the end evaluates to False so I wondered whether there is a programmatic way to extract posterior kernel functions?

My end goal is to use those posterior kernel functions to calculate the posterior predictive of the model so any examples on how to get the posterior predictive from the MCMC samples would be appreciated!

Currently I am manually setting the hyperparameters of the kernel after inference and then I am using this function:

def gp_analytic_posterior(
    kernel_fn,
    X,
    new_xs,
    y,
    noise,
    jitter,
    full_cov=False,
):
    N = X.size(0)
    Kff = kernel_fn(X).contiguous()
    Kff = Kff.type(X.dtype).clone()
    Kff.view(-1)[:: N + 1] += jitter + torch.pow(noise, 2)
    Lff = torch.linalg.cholesky(Kff)

    gp_post_mean, gp_post_cov = gp.util.conditional(
        new_xs, X, kernel_fn, y, Lff=Lff, jitter=jitter, full_cov=full_cov
    )
    if full_cov:
        M = new_xs.size(0)
        gp_post_cov = gp_post_cov.contiguous()
        gp_post_cov.view(-1, M * M)[:, :: M + 1] += torch.pow(noise, 2)
    else:
        gp_post_cov = gp_post_cov + torch.pow(noise, 2)
    return gp_post_mean, gp_post_cov

which is based of the code in the GPRegression module.

Thanks for the help!

1 Like

I have found that this method seems to work:

def extract_posterior_kernels(posterior_samples: list[pyro.poutine.Trace]):
      post_kernels = [trace.nodes["_RETURN"]["value"] for trace in posterior_samples]
      for ix in range(len(post_kernels)):
          for name, s in posterior_samples[ix].iter_stochastic_nodes():
              if name == "std" or name == "y":
                    continue

              if isinstance(post_kernels[ix], gp.kernels.Sum) or isinstance(
                  post_kernels[ix], gp.kernels.Product
              ):
                  names = name.split(".")
                  kern_mod = post_kernels[ix]._modules[names[0]]
                  for jx in range(len(names) - 2):
                      kern_mod = kern_mod._modules[names[jx + 1]]
                  setattr(kern_mod, names[-1], s["value"])
              else:
                  setattr(post_kernels[ix], name, s["value"])
      return post_kernels

There is probably a cleaner way to do this. But I found this works reliably for a large variety of kernels.