Hello @fehiepsi
Well, I'm a bit puzzled: here is an adaptation of one of your favorite examples, and AutoMVN (AutoMultivariateNormal) & TraceMeanField_ELBO seem to be OK together.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import expit
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, TraceMeanField_ELBO, autoguide
from numpyro.util import enable_x64
import matplotlib.pyplot as plt
import matplotlib as mpl
# Global matplotlib styling for every figure in this script:
# "jet" as the default image colormap, 16pt Times New Roman text.
mpl.rcParams.update(
    {
        "image.cmap": "jet",
        "font.size": 16,
        "font.family": "Times New Roman",
    }
)
# squared exponential kernel
def kernel(X, Z, length, jitter=1.0e-6):
deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
k = jnp.exp(-0.5 * deltaXsq) + jitter * jnp.eye(X.shape[0])
return k
def model(X, Y, length=0.2):
# compute kernel
k = kernel(X, X, length)
# sample from gaussian process prior
f = numpyro.sample(
"f",
dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
)
# we use a non-standard link function to induce extra non-gaussianity
numpyro.sample("obs", dist.Bernoulli(logits=jnp.power(f, 3.0)), obs=Y)
# create artificial binary classification dataset
def get_data(N=16):
np.random.seed(0)
X = np.linspace(-1, 1, N)
Y = X + 0.2 * np.power(X, 3.0) + 0.5 * np.power(0.5 + X, 2.0) * np.sin(4.0 * X)
Y -= np.mean(Y)
Y /= np.std(Y)
Y = np.random.binomial(1, expit(Y))
assert X.shape == (N,)
assert Y.shape == (N,)
return X, Y
# helper function for running SVI with a particular autoguide
def run_svi(rng_key, X, Y, guide_family="AutoDiagonalNormal", K=8, loss=None):
assert guide_family in ["AutoDiagonalNormal", "AutoDAIS", "AutoMultivariateNormal"]
if guide_family == "AutoDAIS":
guide = autoguide.AutoDAIS(model, K=K, eta_init=0.02, eta_max=0.5)
step_size = 5e-4
elif guide_family == "AutoMultivariateNormal":
guide = autoguide.AutoMultivariateNormal(model)
step_size = 3e-3
optimizer = numpyro.optim.Adam(step_size=step_size)
svi = SVI(model, guide, optimizer, loss=loss())
svi_result = svi.run(rng_key, 20_000, X, Y)
params = svi_result.params
final_elbo = -loss(num_particles=1000).loss(
rng_key, params, model, guide, X, Y
)
guide_name = guide_family
if guide_family == "AutoDAIS":
guide_name += "-{}".format(K)
print("[{}] final elbo: {:.2f}".format(guide_name, final_elbo))
return guide.sample_posterior(
random.PRNGKey(1), params, sample_shape=(1000,)
)
# helper function for running mcmc
def run_nuts(mcmc_key, args, X, Y):
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(mcmc_key, X, Y)
mcmc.print_summary()
return mcmc.get_samples()
enable_x64()
X, Y = get_data()
rng_keys = random.split(random.PRNGKey(0), 4)
run_svi(rng_keys[1], X, Y, guide_family="AutoDAIS", K=8, loss=Trace_ELBO)
run_svi(rng_keys[1], X, Y, guide_family="AutoMultivariateNormal", loss=Trace_ELBO)
run_svi(rng_keys[1], X, Y, guide_family="AutoMultivariateNormal", loss=TraceMeanField_ELBO)
There is no traceback for the last call.