Hey folks, I have been using numpyro for a bit. I noticed some behavior of `Predictive` that seems weird to me, and I was hoping to get some clarification. Basically, `Predictive` gives different answers depending on whether you pass `posterior_samples`, `guide`, `params`, or both `guide` + `params`. Why this should be the case isn't intuitive to me.
Here’s a complete example.
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
# Number of synthetic observations (rows of X).
N = 10000
# Step size handed to the Adam optimizer below.
LR = 0.05
# Passed as the second positional argument to svi.run in the FIT section.
N_EPOCHS = 10000
# Number of regression coefficients (columns of X, length of beta).
K = 10
# NOTE(review): not referenced anywhere else in the visible script.
LOW_RANK_DIM = 3
# Optimizer instance passed to SVI in the FIT section.
adam = numpyro.optim.Adam(step_size=LR)
def model(X, y=None):
    """Linear regression model with Normal priors and a Normal likelihood.

    Declares three latent sample sites -- a scalar intercept ``alpha``, a
    coefficient vector ``beta`` expanded to length ``K``, and a noise scale
    ``sigma`` -- and one ``obs`` site centred on ``alpha + X @ beta``.

    Args:
        X: Design matrix; its trailing dimension must be compatible with the
            length-``K`` ``beta`` for ``jnp.dot``.
        y: Responses forwarded as ``obs=`` to the ``obs`` site. With the
            default ``None`` the site is not conditioned on data.
    """
    alpha = numpyro.sample("alpha", numpyro.distributions.Normal(0, 1))
    beta = numpyro.sample("beta", numpyro.distributions.Normal(0, 1).expand([K]))
    sigma = numpyro.sample("sigma", numpyro.distributions.Exponential(1))
    # Linear predictor: intercept plus one dot product per row of X.
    y_hat = alpha + jnp.dot(X, beta)
    numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)
#### DATA GEN ####
# Split one seed into three keys: one each for X, the true beta, and the noise in y.
rng_key1, rng_key2, rng_key3 = jax.random.split(jax.random.PRNGKey(42), num=3)
# True noise scale used to simulate y.
sigma = 1.0
# Design matrix of standard-normal draws, shape (N, K).
X = numpyro.distributions.Normal(0, 1).sample(rng_key1, sample_shape=(N, K))
# True coefficients, shape (K,). This module-level name is separate from the
# "beta" sample site declared inside model.
beta = numpyro.distributions.Normal(0, 1).sample(rng_key2, sample_shape=(K,))
# Noise-free mean. No intercept term is added, so the data are generated with
# an effective intercept of 0.
mu = jnp.dot(X, beta)
# Simulated responses: one Normal draw around each entry of mu.
y = numpyro.distributions.Normal(mu, sigma).sample(rng_key3, sample_shape=())
#### FIT ####
# Automatic guide constructed from model; it is the same object later passed
# to Predictive in the RESULTS section.
guide = AutoNormal(model)
svi = SVI(model, guide, optim=adam, loss=Trace_ELBO())
rng_key = jax.random.PRNGKey(99)
# Run N_EPOCHS optimization steps; the fitted values are read from
# svi_result.params below.
svi_result = svi.run(rng_key, N_EPOCHS, X=X, y=y)
### RESULTS ####
# Each variant below draws 1000 predictive samples of "obs", averages them over
# the sample axis, and prints the correlation of that mean with the simulated y.

# PARAMS ONLY -- wrong, 0.086 corr with true value
# NOTE(review): neither a guide nor posterior_samples is supplied here, and
# model declares no numpyro.param sites, so the latent sites are presumably
# drawn from their priors (a prior predictive) -- confirm against the
# Predictive documentation.
pred_params_only = Predictive(model, params=svi_result.params, num_samples=1000)
full_samples = pred_params_only(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples['obs'], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])
# GUIDE ONLY -- wrong, 0.223 corr with true value
# NOTE(review): the fitted svi_result.params are not passed, so the guide
# presumably samples with its initial (unoptimized) parameter values --
# confirm against the Predictive documentation.
pred_guide_only = Predictive(model, guide=guide, num_samples=1000)
full_samples = pred_guide_only(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples['obs'], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])
# BOTH -- correct
# The guide is supplied together with the fitted parameters from svi_result.
pred_params_plus_guide = Predictive(model, params=svi_result.params, guide=guide, num_samples=1000)
full_samples = pred_params_plus_guide(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples['obs'], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])
# PASS POSTERIOR DIRECTLY -- correct
# Two stages: first sample the latent sites from the guide using the fitted
# params, then feed those samples to the model via posterior_samples.
guide_pred = Predictive(guide, params=svi_result.params, num_samples=1000)
posterior_samples = guide_pred(jax.random.PRNGKey(1), X=X)
pred_posterior = Predictive(model, posterior_samples=posterior_samples, num_samples=1000)
full_samples = pred_posterior(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples['obs'], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])
Note that using `Predictive` with only `params` or only `guide` gives wrong answers, but different wrong answers. Passing both `params` and `guide` works, as does passing the posterior samples directly.
Hoping to get some insight into what’s going on under the hood and the rationale for this behavior.