Hi, thanks for the quick reply! Here is a complete runnable script. A big part is just getting a toy dataset. The actual model bit is towards the end. The last line gives an error. I was looking into using plates but am not sure how to use a plate to index the particular input-output pair or hyperparameters for each site.
import os
import time
import jax.numpy as np
from jax import random
from jax import vmap, jit
from functools import partial
import numpyro
from numpyro.infer import init_to_median, Predictive, MCMC, NUTS
import numpyro.distributions as dist
@partial(jit, static_argnums=(5, 6))
def rbf_covariance(var, length, noise, x, xp, jitter=1.0e-6,
include_noise=True):
diff = np.expand_dims(x / length, 1) - np.expand_dims(xp / length, 0)
Z = var * np.exp(-0.5 * np.sum(diff**2, axis=2)) # ! axis = 2 ??
if include_noise:
return Z + (noise + jitter) * np.eye(x.shape[0])
else:
return Z
## generate dataset
###############################
def univariate_gp(x, y, mu, var, noise, length, jitter=1e-06):
# compute kernel
K = rbf_covariance(var, length, noise, x.reshape(-1,1), x.reshape(-1,1), jitter=jitter)
numpyro.sample(
"obs_y",
dist.MultivariateNormal(loc=mu, covariance_matrix=K),
obs=y,
)
def gen_outer_gp(key, proj_data, mu, var, noise, length, jitter=1e-06):
predictive = Predictive(univariate_gp, num_samples=1)
pred = predictive(key, proj_data, None, mu, var, noise, length, jitter=jitter)
return pred['obs_y'].flatten()
input_key = random.PRNGKey(778989)
projs_key = random.PRNGKey(3257)
outer_key = random.PRNGKey(2357)
D = 2
N = 50
n_s = 2
var = 1.0
noise = 0.1
length = 0.5
x1 = np.linspace(0,1,N).reshape(-1,1)
x2 = np.linspace(1,2,N).reshape(-1,1)
Xs = np.hstack((x1,x2))
def mu_f(x):
y = x + 0.2 * (x ** 3) + 0.5 * ((0.5 + x) ** 2) * np.sin(4.0 * x)
return y
mu_s = vmap(mu_f, in_axes=1, out_axes=1)(Xs)
outer_keys = random.split(outer_key, n_s)
Ys = vmap(lambda key, x, mu: gen_outer_gp(key, x, mu, var, noise, length), in_axes=(0,1,1), out_axes=1)(outer_keys, Xs, mu_s)
#######################################################
## model
def model(Xs, Ys, jitter=1e-6):
N, n_s = Xs.shape
for i in range(n_s):
x = Xs[:,i]
y = Ys[:,i]
K = rbf_covariance(var, length, noise, x.reshape(-1,1), x.reshape(-1,1), jitter=jitter)
numpyro.sample(
f"Ys[:,{i}]",
dist.MultivariateNormal(loc=np.zeros(K.shape[0]), covariance_matrix=K),
obs=y
)
# options for model
mcmc_config = {'num_warmup' : 1000, 'num_samples' : 1000, 'num_chains' : 1, 'thinning' : 2, 'init_strategy' : init_to_median(num_samples=10)}
#
seed = 342757
train_key = random.PRNGKey(seed)
# helper function for doing hmc inference
def run_inference(rng_key, mcmc_config, model, *args):
num_warmup = mcmc_config['num_warmup']
num_samples = mcmc_config['num_samples']
num_chains = mcmc_config['num_chains']
thinning = mcmc_config['thinning']
init_strategy = mcmc_config['init_strategy']
start = time.time()
kernel = NUTS(model, init_strategy=init_strategy)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains, thinning=thinning,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
mcmc.run(rng_key, *args)
print('\nMCMC elapsed time:', time.time() - start)
return mcmc
mcmc = run_inference(train_key, mcmc_config, model, Xs, Ys)
# throws error
mcmc.print_summary()
# this would return {}
# mcmc.get_samples()