Hello, I tried to construct the gaussian process model with student-t likelihood based on the following pages.
However, r_hat
cannot be less than 1.1. Is there still anything I can do to fix this problem? Thank you.
Code
import argparse
import os
import time
import numpy as np
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
MCMC,
NUTS,
init_to_feasible,
init_to_median,
init_to_sample,
init_to_uniform,
init_to_value,
)
def kernel(X, Z, var, length, jitter=1.0e-6):
deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
k = var * jnp.exp(-0.5 * deltaXsq) + jitter * jnp.eye(X.shape[0])
return k
def model(X, Y):
var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
noise = numpyro.sample("likelihood_noise", dist.LogNormal(0.0, 10.0))
df = numpyro.sample("likelihood_df", dist.LogNormal(0.0, 10.0))
k = kernel(X, X, var, length)
f = numpyro.sample(
"f",
dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
)
numpyro.sample("obs", dist.StudentT(df=df, loc=f, scale=noise), obs=Y)
def run_inference(model, args, rng_key, X, Y):
start = time.time()
if args.init_strategy == "value":
init_strategy = init_to_value(
values={"kernel_var": 1.0, "kernel_noise": 0.05, "kernel_length": 0.5}
)
elif args.init_strategy == "median":
init_strategy = init_to_median(num_samples=10)
elif args.init_strategy == "feasible":
init_strategy = init_to_feasible()
elif args.init_strategy == "sample":
init_strategy = init_to_sample()
elif args.init_strategy == "uniform":
init_strategy = init_to_uniform(radius=1)
kernel = NUTS(model, init_strategy=init_strategy)
mcmc = MCMC(
kernel,
num_warmup=args.num_warmup,
num_samples=args.num_samples,
num_chains=args.num_chains,
thinning=args.thinning,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(rng_key, X, Y)
mcmc.print_summary()
print("\nMCMC elapsed time:", time.time() - start)
return mcmc.get_samples()
def get_data(N=30, sigma_obs=0.15, N_test=400):
np.random.seed(0)
X = jnp.linspace(-1, 1, N)
Y = X + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0) * jnp.sin(4.0 * X)
Y += sigma_obs * np.random.randn(N)
Y -= jnp.mean(Y)
Y /= jnp.std(Y)
assert X.shape == (N,)
assert Y.shape == (N,)
X_test = jnp.linspace(-1.3, 1.3, N_test)
return X, Y, X_test
def main(args):
X, Y, X_test = get_data(N=args.num_data)
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
samples = run_inference(model, args, rng_key, X, Y)
if __name__ == "__main__":
assert numpyro.__version__.startswith("0.13.2")
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
parser.add_argument("--num-chains", nargs="?", default=3, type=int)
parser.add_argument("--thinning", nargs="?", default=2, type=int)
parser.add_argument("--num-data", nargs="?", default=25, type=int)
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
parser.add_argument(
"--init-strategy",
default="median",
type=str,
choices=["median", "feasible", "value", "uniform", "sample"],
)
args = parser.parse_args()
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
main(args)
Output
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695711481.764211 570 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
Running chain 0: 100%|███████████████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 43.34it/s]
Running chain 1: 100%|███████████████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 43.34it/s]
Running chain 2: 100%|███████████████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 43.34it/s]
mean std median 5.0% 95.0% n_eff r_hat
f[0] -0.01 0.02 0.00 -0.04 0.01 1.50 29.28
f[1] -0.01 0.02 0.00 -0.04 0.00 1.50 34.17
f[2] -0.01 0.02 0.00 -0.04 0.00 1.50 23.70
f[3] -0.01 0.02 0.00 -0.04 0.00 1.50 30.59
f[4] -0.01 0.02 0.00 -0.04 0.00 1.50 28.47
f[5] -0.01 0.02 0.00 -0.04 0.00 1.50 25.04
f[6] -0.01 0.02 0.00 -0.04 0.01 1.50 21.95
f[7] -0.01 0.02 0.00 -0.04 0.01 1.50 32.25
f[8] -0.01 0.02 0.00 -0.04 0.01 1.50 32.69
f[9] -0.01 0.02 0.00 -0.04 0.00 1.50 33.16
f[10] -0.01 0.02 0.00 -0.04 0.01 1.51 25.94
f[11] -0.01 0.02 0.00 -0.04 0.01 1.50 25.57
f[12] -0.01 0.02 0.00 -0.04 0.01 1.50 30.18
f[13] -0.01 0.02 0.00 -0.04 0.00 1.50 21.76
f[14] -0.01 0.02 0.00 -0.04 0.01 1.50 28.31
f[15] -0.01 0.02 -0.00 -0.04 0.01 1.51 22.66
f[16] -0.01 0.02 0.00 -0.04 0.01 1.51 22.58
f[17] -0.01 0.02 0.00 -0.04 0.01 1.51 21.02
f[18] -0.01 0.02 0.00 -0.04 0.01 1.50 23.40
f[19] -0.01 0.02 0.00 -0.04 0.01 1.50 30.69
f[20] -0.01 0.02 0.00 -0.04 0.00 1.50 24.92
f[21] -0.01 0.02 0.00 -0.04 0.00 1.50 26.15
f[22] -0.01 0.02 0.00 -0.04 0.01 1.50 28.08
f[23] -0.01 0.02 0.00 -0.04 0.00 1.50 26.85
f[24] -0.01 0.02 0.00 -0.04 0.00 1.50 29.77
kernel_length 3063.60 3376.91 1322.14 84.83 7787.93 1.50 954.54
kernel_var 0.04 0.06 0.01 0.00 0.12 1.50 730.53
likelihood_df 1206050.88 146733.44 1150100.00 1060936.38 1409329.50 1.50 128.50
likelihood_noise 1.05 0.14 1.12 0.86 1.18 1.50 125.45
Number of divergences: 0
MCMC elapsed time: 46.6217155456543