How to construct a Gaussian process model with a Student-t likelihood?

Hello, I tried to construct a Gaussian process model with a Student-t likelihood, based on the following pages.

However, I cannot get r_hat below 1.1. Is there anything else I can do to fix this problem? Thank you.

Code

import argparse
import os
import time

import numpy as np

import jax.numpy as jnp
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    MCMC,
    NUTS,
    init_to_feasible,
    init_to_median,
    init_to_sample,
    init_to_uniform,
    init_to_value,
)


def kernel(X, Z, var, length, jitter=1.0e-6, include_jitter=True):
    """Squared-exponential (RBF) covariance between two sets of 1-D inputs.

    Entry ``(i, j)`` is ``var * exp(-0.5 * ((X[i] - Z[j]) / length) ** 2)``.

    Args:
        X: 1-D array of N input locations.
        Z: 1-D array of M input locations.
        var: Kernel amplitude (signal variance).
        length: Kernel length scale.
        jitter: Value added to the diagonal when ``include_jitter`` is true.
        include_jitter: Add ``jitter * I`` to the result. This only makes
            sense when ``X`` and ``Z`` are the same points (a Gram matrix).
            Pass ``False`` for a cross-covariance, e.g. training vs. test
            inputs. There a diagonal term is wrong, and for N != M it
            cannot even be broadcast.

    Returns:
        An (N, M) covariance matrix.
    """
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_jitter:
        # A small diagonal term keeps the Gram matrix numerically positive
        # definite for the Cholesky factorisation inside MultivariateNormal.
        k = k + jitter * jnp.eye(X.shape[0])
    return k


def model(X, Y):
    """Gaussian-process regression with a Student-t observation likelihood.

    The latent function ``f`` at the inputs ``X`` gets a zero-mean
    multivariate normal prior whose covariance comes from ``kernel``. The
    observations ``Y`` are Student-t distributed around ``f``. All four
    positive hyperparameters have very wide LogNormal(0, 10) priors.
    """
    # Sites are sampled in a fixed order:
    # kernel_var, kernel_length, likelihood_noise, likelihood_df, f, obs.
    amplitude = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    lengthscale = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    obs_scale = numpyro.sample("likelihood_noise", dist.LogNormal(0.0, 10.0))
    nu = numpyro.sample("likelihood_df", dist.LogNormal(0.0, 10.0))

    n_points = X.shape[0]
    gram = kernel(X, X, amplitude, lengthscale)
    latent_prior = dist.MultivariateNormal(
        loc=jnp.zeros(n_points), covariance_matrix=gram
    )
    f = numpyro.sample("f", latent_prior)

    likelihood = dist.StudentT(df=nu, loc=f, scale=obs_scale)
    numpyro.sample("obs", likelihood, obs=Y)


def run_inference(model, args, rng_key, X, Y):
    """Sample the model's posterior with NUTS, print a summary and timing.

    Args:
        model: NumPyro model callable, invoked as ``model(X, Y)``.
        args: Parsed CLI namespace. Reads ``init_strategy``, ``num_warmup``,
            ``num_samples``, ``num_chains`` and ``thinning``.
        rng_key: JAX PRNG key passed to ``MCMC.run``.
        X: Training inputs forwarded to ``model``.
        Y: Training targets forwarded to ``model``.

    Returns:
        The posterior draws from ``MCMC.get_samples()``, keyed by site name.

    Raises:
        ValueError: If ``args.init_strategy`` is not a recognised strategy.
    """
    start = time.time()
    if args.init_strategy == "value":
        # Keys must match sample-site names declared in the model. A key that
        # matches no site is never looked up, so a typo silently has no effect.
        init_strategy = init_to_value(
            values={"kernel_var": 1.0, "likelihood_noise": 0.05, "kernel_length": 0.5}
        )
    elif args.init_strategy == "median":
        init_strategy = init_to_median(num_samples=10)
    elif args.init_strategy == "feasible":
        init_strategy = init_to_feasible()
    elif args.init_strategy == "sample":
        init_strategy = init_to_sample()
    elif args.init_strategy == "uniform":
        init_strategy = init_to_uniform(radius=1)
    else:
        # Without this branch an unknown name would surface later as an
        # UnboundLocalError on ``init_strategy``.
        raise ValueError(f"Unknown init strategy: {args.init_strategy!r}")

    # Named ``nuts_kernel`` so it does not shadow the module-level GP ``kernel``.
    nuts_kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        thinning=args.thinning,
        # The progress bar is suppressed when building documentation.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
    return mcmc.get_samples()


def get_data(N=30, sigma_obs=0.15, N_test=400):
    """Generate a standardised synthetic 1-D regression dataset.

    Targets are ``x + 0.2 x^3 + 0.5 (0.5 + x)^2 sin(4x)`` on ``N`` evenly
    spaced points in [-1, 1], plus Gaussian noise with standard deviation
    ``sigma_obs``. The result is then centred and scaled to unit standard
    deviation.

    Args:
        N: Number of training points.
        sigma_obs: Standard deviation of the additive noise (applied before
            standardisation).
        N_test: Number of evenly spaced test inputs on [-1.3, 1.3].

    Returns:
        Tuple ``(X, Y, X_test)`` with shapes ``(N,)``, ``(N,)`` and ``(N_test,)``.
    """
    # A private, fixed-seed generator yields exactly the same draws as
    # ``np.random.seed(0); np.random.randn(N)``. Unlike that form, it leaves
    # NumPy's global RNG state alone for the rest of the program.
    rng = np.random.RandomState(0)
    X = jnp.linspace(-1, 1, N)
    Y = X + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0) * jnp.sin(4.0 * X)
    Y += sigma_obs * rng.randn(N)
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N,)
    assert Y.shape == (N,)

    X_test = jnp.linspace(-1.3, 1.3, N_test)

    return X, Y, X_test


def main(args):
    """Build the synthetic dataset and run NUTS inference on ``model``.

    ``args.num_data`` sets the number of training points, and the whole
    namespace is forwarded to ``run_inference``. The test inputs and the
    second PRNG key are produced but not used yet. They are presumably
    reserved for posterior prediction.
    """
    X, Y, _ = get_data(N=args.num_data)
    inference_key, _ = random.split(random.PRNGKey(0))
    run_inference(model, args, inference_key, X, Y)


if __name__ == "__main__":
    # The script is pinned to the NumPyro release it was written against.
    assert numpyro.__version__.startswith("0.13.2")

    parser = argparse.ArgumentParser(description="Gaussian Process example")
    # Integer options with optional values, registered in their original order.
    for flags, default in (
        (("-n", "--num-samples"), 1000),
        (("--num-warmup",), 1000),
        (("--num-chains",), 3),
        (("--thinning",), 2),
        (("--num-data",), 25),
    ):
        parser.add_argument(*flags, nargs="?", default=default, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    parser.add_argument(
        "--init-strategy",
        default="median",
        type=str,
        choices=["median", "feasible", "value", "uniform", "sample"],
    )
    args = parser.parse_args()

    # Configure the JAX backend before any computation runs.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)

Output

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695711481.764211     570 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
Running chain 0: 100%|███████████████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 43.34it/s]
Running chain 1: 100%|███████████████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 43.34it/s]
Running chain 2: 100%|███████████████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 43.34it/s]

                        mean       std    median      5.0%     95.0%     n_eff     r_hat
              f[0]     -0.01      0.02      0.00     -0.04      0.01      1.50     29.28
              f[1]     -0.01      0.02      0.00     -0.04      0.00      1.50     34.17
              f[2]     -0.01      0.02      0.00     -0.04      0.00      1.50     23.70
              f[3]     -0.01      0.02      0.00     -0.04      0.00      1.50     30.59
              f[4]     -0.01      0.02      0.00     -0.04      0.00      1.50     28.47
              f[5]     -0.01      0.02      0.00     -0.04      0.00      1.50     25.04
              f[6]     -0.01      0.02      0.00     -0.04      0.01      1.50     21.95
              f[7]     -0.01      0.02      0.00     -0.04      0.01      1.50     32.25
              f[8]     -0.01      0.02      0.00     -0.04      0.01      1.50     32.69
              f[9]     -0.01      0.02      0.00     -0.04      0.00      1.50     33.16
             f[10]     -0.01      0.02      0.00     -0.04      0.01      1.51     25.94
             f[11]     -0.01      0.02      0.00     -0.04      0.01      1.50     25.57
             f[12]     -0.01      0.02      0.00     -0.04      0.01      1.50     30.18
             f[13]     -0.01      0.02      0.00     -0.04      0.00      1.50     21.76
             f[14]     -0.01      0.02      0.00     -0.04      0.01      1.50     28.31
             f[15]     -0.01      0.02     -0.00     -0.04      0.01      1.51     22.66
             f[16]     -0.01      0.02      0.00     -0.04      0.01      1.51     22.58
             f[17]     -0.01      0.02      0.00     -0.04      0.01      1.51     21.02
             f[18]     -0.01      0.02      0.00     -0.04      0.01      1.50     23.40
             f[19]     -0.01      0.02      0.00     -0.04      0.01      1.50     30.69
             f[20]     -0.01      0.02      0.00     -0.04      0.00      1.50     24.92
             f[21]     -0.01      0.02      0.00     -0.04      0.00      1.50     26.15
             f[22]     -0.01      0.02      0.00     -0.04      0.01      1.50     28.08
             f[23]     -0.01      0.02      0.00     -0.04      0.00      1.50     26.85
             f[24]     -0.01      0.02      0.00     -0.04      0.00      1.50     29.77
     kernel_length   3063.60   3376.91   1322.14     84.83   7787.93      1.50    954.54
        kernel_var      0.04      0.06      0.01      0.00      0.12      1.50    730.53
     likelihood_df 1206050.88 146733.44 1150100.00 1060936.38 1409329.50      1.50    128.50
  likelihood_noise      1.05      0.14      1.12      0.86      1.18      1.50    125.45

Number of divergences: 0

MCMC elapsed time: 46.6217155456543

I'm not sure what's going on, but you might try tighter priors on the model hyperparameters (e.g. likelihood_df), as well as running in 64-bit precision:

from numpyro.util import enable_x64
enable_x64()

You might also try increasing the jitter, e.g. to 1.0e-4.

1 Like

Thank you so much. Switching to 64-bit precision improved things dramatically. In addition, with dist.LogNormal(0.0, 3.0) as the prior for likelihood_df, all r_hat values are now below 1.1.

Code

import argparse
import os
import time

import numpy as np

import jax.numpy as jnp
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    MCMC,
    NUTS,
    init_to_feasible,
    init_to_median,
    init_to_sample,
    init_to_uniform,
    init_to_value,
)
from numpyro.util import enable_x64


def kernel(X, Z, var, length, jitter=1.0e-4, include_jitter=True):
    """Squared-exponential (RBF) covariance between two sets of 1-D inputs.

    Entry ``(i, j)`` is ``var * exp(-0.5 * ((X[i] - Z[j]) / length) ** 2)``.

    Args:
        X: 1-D array of N input locations.
        Z: 1-D array of M input locations.
        var: Kernel amplitude (signal variance).
        length: Kernel length scale.
        jitter: Value added to the diagonal when ``include_jitter`` is true.
        include_jitter: Add ``jitter * I`` to the result. This only makes
            sense when ``X`` and ``Z`` are the same points (a Gram matrix).
            Pass ``False`` for a cross-covariance, e.g. training vs. test
            inputs. There a diagonal term is wrong, and for N != M it
            cannot even be broadcast.

    Returns:
        An (N, M) covariance matrix.
    """
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_jitter:
        # A small diagonal term keeps the Gram matrix numerically positive
        # definite for the Cholesky factorisation inside MultivariateNormal.
        k = k + jitter * jnp.eye(X.shape[0])
    return k


def model(X, Y):
    """Gaussian-process regression with a Student-t observation likelihood.

    The latent function ``f`` at the inputs ``X`` gets a zero-mean
    multivariate normal prior whose covariance comes from ``kernel``. The
    observations ``Y`` are Student-t distributed around ``f``. The kernel
    amplitude, length scale and noise scale have LogNormal(0, 10) priors.
    The degrees of freedom use the tighter LogNormal(0, 3).
    """
    # Sites are sampled in a fixed order:
    # kernel_var, kernel_length, likelihood_noise, likelihood_df, f, obs.
    amplitude = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    lengthscale = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    obs_scale = numpyro.sample("likelihood_noise", dist.LogNormal(0.0, 10.0))
    nu = numpyro.sample("likelihood_df", dist.LogNormal(0.0, 3.0))

    n_points = X.shape[0]
    gram = kernel(X, X, amplitude, lengthscale)
    latent_prior = dist.MultivariateNormal(
        loc=jnp.zeros(n_points), covariance_matrix=gram
    )
    f = numpyro.sample("f", latent_prior)

    likelihood = dist.StudentT(df=nu, loc=f, scale=obs_scale)
    numpyro.sample("obs", likelihood, obs=Y)


def run_inference(model, args, rng_key, X, Y):
    """Sample the model's posterior with NUTS, print a summary and timing.

    Args:
        model: NumPyro model callable, invoked as ``model(X, Y)``.
        args: Parsed CLI namespace. Reads ``init_strategy``, ``num_warmup``,
            ``num_samples``, ``num_chains`` and ``thinning``.
        rng_key: JAX PRNG key passed to ``MCMC.run``.
        X: Training inputs forwarded to ``model``.
        Y: Training targets forwarded to ``model``.

    Returns:
        The posterior draws from ``MCMC.get_samples()``, keyed by site name.

    Raises:
        ValueError: If ``args.init_strategy`` is not a recognised strategy.
    """
    start = time.time()
    if args.init_strategy == "value":
        # Keys must match sample-site names declared in the model. A key that
        # matches no site is never looked up, so a typo silently has no effect.
        init_strategy = init_to_value(
            values={"kernel_var": 1.0, "likelihood_noise": 0.05, "kernel_length": 0.5}
        )
    elif args.init_strategy == "median":
        init_strategy = init_to_median(num_samples=10)
    elif args.init_strategy == "feasible":
        init_strategy = init_to_feasible()
    elif args.init_strategy == "sample":
        init_strategy = init_to_sample()
    elif args.init_strategy == "uniform":
        init_strategy = init_to_uniform(radius=1)
    else:
        # Without this branch an unknown name would surface later as an
        # UnboundLocalError on ``init_strategy``.
        raise ValueError(f"Unknown init strategy: {args.init_strategy!r}")

    # Named ``nuts_kernel`` so it does not shadow the module-level GP ``kernel``.
    nuts_kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        thinning=args.thinning,
        # The progress bar is suppressed when building documentation.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
    return mcmc.get_samples()


def get_data(N=30, sigma_obs=0.15, N_test=400):
    """Generate a standardised synthetic 1-D regression dataset.

    Targets are ``x + 0.2 x^3 + 0.5 (0.5 + x)^2 sin(4x)`` on ``N`` evenly
    spaced points in [-1, 1], plus Gaussian noise with standard deviation
    ``sigma_obs``. The result is then centred and scaled to unit standard
    deviation.

    Args:
        N: Number of training points.
        sigma_obs: Standard deviation of the additive noise (applied before
            standardisation).
        N_test: Number of evenly spaced test inputs on [-1.3, 1.3].

    Returns:
        Tuple ``(X, Y, X_test)`` with shapes ``(N,)``, ``(N,)`` and ``(N_test,)``.
    """
    # A private, fixed-seed generator yields exactly the same draws as
    # ``np.random.seed(0); np.random.randn(N)``. Unlike that form, it leaves
    # NumPy's global RNG state alone for the rest of the program.
    rng = np.random.RandomState(0)
    X = jnp.linspace(-1, 1, N)
    Y = X + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0) * jnp.sin(4.0 * X)
    Y += sigma_obs * rng.randn(N)
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N,)
    assert Y.shape == (N,)

    X_test = jnp.linspace(-1.3, 1.3, N_test)

    return X, Y, X_test


def main(args):
    """Build the synthetic dataset and run NUTS inference on ``model``.

    ``args.num_data`` sets the number of training points, and the whole
    namespace is forwarded to ``run_inference``. The test inputs and the
    second PRNG key are produced but not used yet. They are presumably
    reserved for posterior prediction.
    """
    X, Y, _ = get_data(N=args.num_data)
    inference_key, _ = random.split(random.PRNGKey(0))
    run_inference(model, args, inference_key, X, Y)


if __name__ == "__main__":
    # Switch JAX to double precision. In this thread this was the change that
    # fixed the sampler's convergence.
    enable_x64()
    # The script is pinned to the NumPyro release it was written against.
    assert numpyro.__version__.startswith("0.13.2")

    parser = argparse.ArgumentParser(description="Gaussian Process example")
    # Integer options with optional values, registered in their original order.
    for flags, default in (
        (("-n", "--num-samples"), 1000),
        (("--num-warmup",), 1000),
        (("--num-chains",), 3),
        (("--thinning",), 2),
        (("--num-data",), 25),
    ):
        parser.add_argument(*flags, nargs="?", default=default, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    parser.add_argument(
        "--init-strategy",
        default="median",
        type=str,
        choices=["median", "feasible", "value", "uniform", "sample"],
    )
    args = parser.parse_args()

    # Configure the JAX backend before any computation runs.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)

Result

                        mean       std    median      5.0%     95.0%     n_eff     r_hat
              f[0]     -1.46      0.17     -1.46     -1.78     -1.21   1242.22      1.00
              f[1]     -1.41      0.13     -1.41     -1.63     -1.21   1180.60      1.00
              f[2]     -1.33      0.11     -1.33     -1.51     -1.16   1031.77      1.00
              f[3]     -1.25      0.10     -1.24     -1.42     -1.08    902.38      1.00
              f[4]     -1.15      0.11     -1.15     -1.32     -0.98    657.19      1.00
              f[5]     -1.05      0.11     -1.05     -1.23     -0.88    690.11      1.00
              f[6]     -0.94      0.10     -0.94     -1.10     -0.78    931.70      1.00
              f[7]     -0.82      0.09     -0.83     -0.98     -0.67   1040.29      1.00
              f[8]     -0.68      0.09     -0.68     -0.84     -0.54   1031.20      1.00
              f[9]     -0.53      0.09     -0.53     -0.67     -0.37    965.68      1.00
             f[10]     -0.35      0.10     -0.35     -0.51     -0.19    794.57      1.00
             f[11]     -0.15      0.10     -0.15     -0.31      0.02    720.22      1.00
             f[12]      0.06      0.10      0.06     -0.10      0.23    721.90      1.00
             f[13]      0.28      0.10      0.29      0.13      0.44    783.99      1.00
             f[14]      0.51      0.09      0.51      0.36      0.65    916.05      1.00
             f[15]      0.72      0.09      0.72      0.56      0.86    908.95      1.00
             f[16]      0.91      0.10      0.91      0.75      1.08    706.05      1.00
             f[17]      1.06      0.11      1.06      0.90      1.26    599.81      1.00
             f[18]      1.17      0.12      1.17      0.97      1.35    583.89      1.00
             f[19]      1.22      0.12      1.22      1.04      1.41    597.75      1.00
             f[20]      1.22      0.11      1.22      1.06      1.42    749.83      1.00
             f[21]      1.18      0.11      1.18      1.00      1.38    933.33      1.00
             f[22]      1.09      0.12      1.09      0.90      1.30   1009.75      1.00
             f[23]      0.96      0.15      0.96      0.74      1.20    983.35      1.00
             f[24]      0.82      0.19      0.81      0.54      1.14   1020.20      1.00
     kernel_length      0.71      0.25      0.66      0.37      1.05    574.83      1.00
        kernel_var      3.52     11.92      1.43      0.19      6.62   1128.09      1.00
     likelihood_df    159.28    884.33     18.44      0.46    220.82   1134.14      1.00
  likelihood_noise      0.22      0.05      0.22      0.14      0.30    941.22      1.00