Gaussian process with parametric mean function

Hi, I am running Pyro's Gaussian process tutorial (https://pyro.ai/examples/gp.html).

Now I add a constant bias to the generated data by replacing the line

y = 0.5 * torch.sin(3*X) + dist.Normal(0.0, 0.2).sample(sample_shape=(N,))

with

y = 0.5 * torch.sin(3*X) + dist.Normal(0.0, 0.2).sample(sample_shape=(N,)) + 2.

Moreover, I replace the line that specifies the regression model

gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1.))

with

gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1.),
                             mean_function=lambda _: pyro.param("meanGP", torch.tensor(0.0)))

My problem: the additional parameter I added to learn the constant mean is not learned during inference. I suppose the optimizer does not ‘see’ this parameter.

I tried to modify the inference as follows:

pyro.clear_param_store()
adam_params = {"lr": 0.005, "betas": (0.95, 0.999)}
optimizer = pyro.optim.Adam(adam_params)
# setup the inference algorithm
svi = pyro.infer.SVI(gpr.model, gpr.guide, optimizer, loss=pyro.infer.Trace_ELBO())

n_steps = 1000
# do gradient steps
for step in range(n_steps):
    svi.step()

However, this did not change the outcome.
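
For example, printing the parameter after the SVI loop still shows its initial value:

print(pyro.param("meanGP"))  # tensor(0.), i.e. unchanged from its initial value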

What am I doing wrong? Why is the parameter ‘meanGP’ not learned during inference?

The full Python code is:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import matplotlib.pyplot as plt
import torch

import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist

smoke_test = ('CI' in os.environ)  # ignore; used to check code integrity in the Pyro repo
assert pyro.__version__.startswith('0.4.1')
pyro.enable_validation(True)       # can help with debugging
pyro.set_rng_seed(0)

# note that this helper function does three different things:
# (i) plots the observed data;
# (ii) plots the predictions from the learned GP after conditioning on data;
# (iii) plots samples from the GP prior (with no conditioning on observed data)

def plot(plot_observed_data=False, plot_predictions=False, n_prior_samples=0,
         model=None, kernel=None, n_test=500):

    plt.figure(figsize=(12, 6))
    if plot_observed_data:
        plt.plot(X.numpy(), y.numpy(), 'kx')
    if plot_predictions:
        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs
        # compute predictive mean and variance
        with torch.no_grad():
            if type(model) == gp.models.VariationalSparseGP:
                mean, cov = model(Xtest, full_cov=True)
            else:
                mean, cov = model(Xtest, full_cov=True, noiseless=False)
        sd = cov.diag().sqrt()  # standard deviation at each input point x
        plt.plot(Xtest.numpy(), mean.numpy(), 'r', lw=2)  # plot the mean
        plt.fill_between(Xtest.numpy(),  # plot the two-sigma uncertainty about the mean
                         (mean - 2.0 * sd).numpy(),
                         (mean + 2.0 * sd).numpy(),
                         color='C0', alpha=0.3)
    if n_prior_samples > 0:  # plot samples from the GP prior
        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs
        noise = (model.noise if type(model) != gp.models.VariationalSparseGP
                 else model.likelihood.variance)
        cov = kernel.forward(Xtest) + noise.expand(n_test).diag()
        samples = dist.MultivariateNormal(torch.zeros(n_test), covariance_matrix=cov)\
                      .sample(sample_shape=(n_prior_samples,))
        plt.plot(Xtest.numpy(), samples.numpy().T, lw=2, alpha=0.4)

    plt.xlim(-0.5, 5.5)
    
N = 20
X = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))
y = 0.5 * torch.sin(3*X) + dist.Normal(0.0, 0.2).sample(sample_shape=(N,)) + 2.

plot(plot_observed_data=True)  # let's plot the observed data

kernel = gp.kernels.RBF(input_dim=1, variance=torch.tensor(5.),
                        lengthscale=torch.tensor(10.))
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1.),
                             mean_function=lambda _: pyro.param("meanGP", torch.tensor(0.0)))

plot(model=gpr, kernel=kernel, n_prior_samples=2)

kernel2 = gp.kernels.RBF(input_dim=1, variance=torch.tensor(0.1),
                         lengthscale=torch.tensor(10.))
gpr2 = gp.models.GPRegression(X, y, kernel2, noise=torch.tensor(0.1))
plot(model=gpr2, kernel=kernel2, n_prior_samples=2)

#optimizer = torch.optim.Adam(gpr.parameters(), lr=0.005)
#loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
#losses = []
#num_steps = 2500 if not smoke_test else 2
#for i in range(num_steps):
#    optimizer.zero_grad()
#    loss = loss_fn(gpr.model, gpr.guide)
#    loss.backward()
#    optimizer.step()
#    losses.append(loss.item())


pyro.clear_param_store()
adam_params = {"lr": 0.005, "betas": (0.95, 0.999)}
optimizer = pyro.optim.Adam(adam_params)
# setup the inference algorithm
svi = pyro.infer.SVI(gpr.model, gpr.guide, optimizer, loss=pyro.infer.Trace_ELBO())

n_steps = 1000
# do gradient steps
for step in range(n_steps):
    svi.step()


plot(model=gpr, plot_observed_data=True, plot_predictions=True)

@summit The Pyro GP module does not rely on the param store, so the pyro.param primitive will not work. GP modules are subclasses of PyTorch's nn.Module, so you only need to make your mean function an nn.Module. You can take a look at this example, where the mean function is used to capture the linear trend of the data.
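
For instance, something along these lines should work (a minimal sketch; the class name ConstantMean is my own, and the training loop is the torch.optim one from the tutorial that you left commented out):

import torch
import torch.nn as nn

import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist

N = 20
X = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))
y = 0.5 * torch.sin(3*X) + dist.Normal(0.0, 0.2).sample(sample_shape=(N,)) + 2.

class ConstantMean(nn.Module):
    # constant mean function whose offset is a learnable nn.Parameter
    def __init__(self):
        super().__init__()
        self.constant = nn.Parameter(torch.tensor(0.0))

    def forward(self, X):
        # return one mean value per input point
        return self.constant.expand(X.size(0))

kernel = gp.kernels.RBF(input_dim=1, variance=torch.tensor(5.),
                        lengthscale=torch.tensor(10.))
mean_fn = ConstantMean()
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1.),
                             mean_function=mean_fn)

# GPRegression is itself an nn.Module, so assigning mean_fn registers it as a
# submodule and its parameter shows up in gpr.parameters(). Optimize those
# parameters directly with torch.optim instead of pyro.infer.SVI:
optimizer = torch.optim.Adam(gpr.parameters(), lr=0.005)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
for i in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()

print(mean_fn.constant.item())  # should move toward the true offset of 2

A linear trend works the same way: let forward return something like a * X + b, with a and b registered as nn.Parameters.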