Cannot get posterior predictive for simple GP

I am learning numpyro by coding a simple Gaussian Process. The problem is that the posterior predictive does not predict anything: the samples seem to come from the prior. Inference with NUTS itself seems to work, though, because plugging the posterior hyperparameters into the usual analytical GP formula gives the correct posterior. Maybe I don't understand how to use Predictive.

This is the code:

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

rng_key = random.PRNGKey(0)
# get_dataset is my own loader: labelled inputs (n, d), targets (n,), unlabelled inputs (m, d)
xL, yL, xU = get_dataset()

def gp_kernel(x, z, var, length, noise):
    # with meshgrid's default 'xy' indexing the result has shape (len(z), len(x))
    g1, g2 = jnp.meshgrid(jnp.arange(len(x)), jnp.arange(len(z)))
    # `length` divides the squared distance, so it acts as a squared lengthscale
    k = var * jnp.exp(-0.5 * jnp.sum((x[g1] - z[g2])**2, axis=-1) / length)
    if noise is not None:
        # noise (plus jitter) on the diagonal; only valid when x and z are the same points
        k += (noise + 1e-6) * jnp.eye(len(x))
    return k


def model(X, Y):
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 1.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 1.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 1.0))

    k = gp_kernel(X, X, var, length, noise)
    numpyro.sample(
        'f',
        dist.MultivariateNormal(loc=jnp.zeros(len(X)), covariance_matrix=k),
        obs=Y
    )

# fit model
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
mcmc.run(rng_key, xL, yL)
mcmc.print_summary()
samples_1 = mcmc.get_samples()

# posterior via Predictive
predictive = Predictive(model, samples_1)
preds = predictive(rng_key, xL, None)
preds['f'].mean(axis=0)  # this posterior mean is always very close to zero

# posterior via analytical formula (plugging in posterior means of the hyperparameters)
var = samples_1['kernel_var'].mean()
length = samples_1['kernel_length'].mean()
noise = samples_1['kernel_noise'].mean()
kUL = gp_kernel(xU, xL, var, length, None)   # shape (len(xL), len(xU))
kUU = gp_kernel(xU, xU, var, length, noise)
kLL = gp_kernel(xL, xL, var, length, noise)
kLLi = jnp.linalg.inv(kLL)

mU = kUL.T @ kLLi @ yL  # this posterior mean is correct
sU = kUU - kUL.T @ kLLi @ kUL
```
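For reference, for fixed hyperparameters the analytical block computes the standard Gaussian conditional. Written in the usual (rows, columns) convention, where K_{UL} has one row per point in xU (the code's `kUL` is its transpose because of how `gp_kernel` uses `meshgrid`), and with σ² standing for `kernel_noise` plus the 1e-6 jitter:

```latex
\begin{aligned}
\mu_U    &= K_{UL}\,\bigl(K_{LL} + \sigma^2 I\bigr)^{-1} y_L,\\
\Sigma_U &= K_{UU} + \sigma^2 I - K_{UL}\,\bigl(K_{LL} + \sigma^2 I\bigr)^{-1} K_{LU}.
\end{aligned}
```

Using the posterior means of the hyperparameters, as the code does, is a plug-in approximation rather than an average over their posterior.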

Thanks!


I think Predictive can't be used for a GP in that way. MCMC gives you p(theta|X,y). What you need to draw from is p(y_new|X_new,X,y), which is equal to the integral over theta of p(y_new|X_new,X,y,theta) p(theta|X,y). I'll call the first term q(y_new). So you can take samples of theta from MCMC and use Predictive on a q model, which follows your analytical formula. The way you used Predictive draws from p(y_new|X_new,theta), which is different from q(y_new): with Y=None, the 'f' site is simply sampled from the zero-mean GP prior for each theta, which is why its mean is close to zero. Predictive can be used for prediction with the same model on new data only when, conditioned on the latent variables, the likelihoods of the new data and the training data are independent. That is a typical assumption in many models, but it does not hold for a GP.

Silly me, thank you for the explanation. Is there no other way of getting p(y_new|X_new,X,y) with pyro? I’d like to avoid deriving the analytical solution every time I make a change to the GP. I guess I’d need to resort to (S)VI?

I'm not sure there is an easy way to do it. You might want to write an additional model for prediction, similar to the one in the Pyro documentation page on Gaussian Processes, and abstract it out in some way for your use case (as we did for the Pyro GP models).
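One way to avoid re-deriving anything when only the kernel changes is to keep the Gaussian conditioning in a single helper that takes the kernel as an argument. A minimal pure-JAX sketch, assuming a Gaussian likelihood and a kernel signature `kernel_fn(x, z, **params)`; `gp_conditional`, `rbf` and `exponential` are hypothetical names, and the helper returns the conditional of the latent function (no noise added at the new points):

```python
import jax.numpy as jnp


def rbf(x, z, var=1.0, length=1.0):
    # squared-exponential kernel, (n, d) x (m, d) -> (n, m)
    d2 = jnp.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return var * jnp.exp(-0.5 * d2 / length**2)


def exponential(x, z, var=1.0, length=1.0):
    # Matern-1/2 (exponential) kernel; tiny offset keeps sqrt differentiable at 0
    d = jnp.sqrt(jnp.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1) + 1e-12)
    return var * jnp.exp(-d / length)


def gp_conditional(kernel_fn, X, Y, X_new, noise, **params):
    # the conditioning only touches kernel matrices, so it is identical
    # for every kernel_fn; swapping kernels needs no new derivation
    k_LL = kernel_fn(X, X, **params) + noise * jnp.eye(X.shape[0])
    k_UL = kernel_fn(X_new, X, **params)
    k_UU = kernel_fn(X_new, X_new, **params)
    mean = k_UL @ jnp.linalg.solve(k_LL, Y)
    cov = k_UU - k_UL @ jnp.linalg.solve(k_LL, k_UL.T)
    return mean, cov
```

Inside a prediction model, `mean` and `cov` would parameterize `dist.MultivariateNormal` for the new points, with the noise variance added to `cov` if noisy observations are wanted.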
