I’m trying to adapt the Hilbert Space GP Approx to a multivariate example (the Friedman function) without much success.
import argparse
import os
import matplotlib.pyplot as plt
import pandas as pd
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
import numpyro
from numpyro import deterministic, plate, sample
import numpyro.distributions as dist
from numpyro.handlers import scope
from numpyro.infer import MCMC, NUTS, init_to_median
# --- Modelling utility functions --- #
def spectral_density(w, alpha, length):
c = alpha * jnp.sqrt(2 * jnp.pi) * length
e = jnp.exp(-0.5 * (length**2) * (w**2))
return c * e
def diag_spectral_density(alpha, length, L, M):
sqrt_eigenvalues = jnp.arange(1, 1 + M) * jnp.pi / 2 / L
return spectral_density(sqrt_eigenvalues, alpha, length)
def eigenfunctions(x, L, M):
"""
The first `M` eigenfunctions of the laplacian operator in `[-L, L]`
evaluated at `x`. These are used for the approximation of the
squared exponential kernel.
"""
m1 = (jnp.pi / (2 * L)) * jnp.tile(L + x[:, None], M)
m2 = jnp.diag(jnp.linspace(1, M, num=M))
print (m1.shape)
print (m2.shape)
num = jnp.sin(m1 @ m2)
den = jnp.sqrt(L)
return num / den
def modified_bessel_first_kind(v, z):
v = jnp.asarray(v, dtype=float)
return jnp.exp(jnp.abs(z)) * tfp.math.bessel_ive(v, z)
def diag_spectral_density_periodic(alpha, length, M):
"""
Not actually a spectral density but these are used in the same
way. These are simply the first `M` coefficients of the low rank
approximation for the periodic kernel.
"""
a = length ** (-2)
J = jnp.arange(0, M)
c = jnp.where(J > 0, 2, 1)
q2 = (c * alpha**2 / jnp.exp(a)) * modified_bessel_first_kind(J, a)
return q2
def eigenfunctions_periodic(x, w0, M):
"""
Basis functions for the approximation of the periodic kernel.
"""
m1 = jnp.tile(w0 * x[:, None], M)
m2 = jnp.diag(jnp.arange(M, dtype=jnp.float32))
mw0x = m1 @ m2
cosines = jnp.cos(mw0x)
sines = jnp.sin(mw0x)
return cosines, sines
# --- Approximate Gaussian processes --- #
def approx_se_ncp(x, alpha, length, L, M):
"""
Hilbert space approximation for the squared
exponential kernel in the non-centered parametrisation.
"""
phi = eigenfunctions(x, L, M)
spd = jnp.sqrt(diag_spectral_density(alpha, length, L, M))
with plate("basis", M):
beta = sample("beta", dist.Normal(0, 1))
f = deterministic("f", phi @ (spd * beta))
return f
def approx_periodic_gp_ncp(x, alpha, length, w0, M):
"""
Low rank approximation for the periodic squared
exponential kernel in the non-centered parametrisation.
"""
q2 = diag_spectral_density_periodic(alpha, length, M)
cosines, sines = eigenfunctions_periodic(x, w0, M)
with plate("cos_basis", M):
beta_cos = sample("beta_cos", dist.Normal(0, 1))
with plate("sin_basis", M - 1):
beta_sin = sample("beta_sin", dist.Normal(0, 1))
# The first eigenfunction for the sine component
# is zero, so the first parameter wouldn't contribute to the approximation.
# We set it to zero to identify the model and avoid divergences.
zero = jnp.array([0.0])
beta_sin = jnp.concatenate((zero, beta_sin))
f = deterministic("f", cosines @ (q2 * beta_cos) + sines @ (q2 * beta_sin))
return f
# --- Components of the Birthdays model --- #
def trend_gp(x, L, M):
alpha = sample("alpha", dist.HalfNormal(1.0))
length = sample("length", dist.InverseGamma(10.0, 2.0))
f = approx_se_ncp(x, alpha, length, L, M)
return f
def year_gp(x, w0, M):
alpha = sample("alpha", dist.HalfNormal(1.0))
length = sample("length", dist.HalfNormal(0.2)) # scale=0.1 in original
f = approx_periodic_gp_ncp(x, alpha, length, w0, M)
return f
# --- Model --- #
def birthdays_model(
x,
y=None,
):
intercept = sample("intercept", dist.Normal(0, 1))
f1 = scope(trend_gp, "trend")(x, 1.5, 10)
# --- Combine components
f = deterministic("f", intercept + f1 )
sigma = sample("sigma", dist.HalfNormal(0.5))
with plate("obs", x.shape[0]):
sample("y", dist.Normal(f, sigma), obs=y)
# --- functions for running the model --- #
import numpy as np
x = np.random.uniform(0,1,(1000,5))
def friedman(x):
10*np.sin(np.pi*x[:,0]*x[:,1])+20*(x[:,2]-.5)**2+10*x[:,3]+5*x[:,4]
y = friedman(x)
mcmc = MCMC(
NUTS(birthdays_model, init_strategy=init_to_median),
num_warmup=500,
num_samples=1000,
num_chains=1)
mcmc.run(jax.random.PRNGKey(0), x=x,y=y)
mcmc.print_summary()
samples = mcmc.get_samples()
I think I don’t fully understand what’s happening in the creation of m1
and m2
and why I’m getting the dimensions I’m getting. Could anyone take a look?