Different arguments to Predictive give different results in SVI?

Hey folks, I have been using numpyro for a while, and I noticed some behavior in Predictive that seems weird to me. I was hoping to get some clarification. Basically, Predictive gives different answers depending on whether you pass posterior_samples, guide, params, or both guide and params, and it isn't intuitive to me why that should be the case.

Here’s a complete example.

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

N = 10000
LR = 0.05
N_EPOCHS = 10000
K = 10
LOW_RANK_DIM = 3

adam = numpyro.optim.Adam(step_size=LR)

def model(X, y=None):
    alpha = numpyro.sample("alpha", numpyro.distributions.Normal(0, 1))
    beta = numpyro.sample("beta", numpyro.distributions.Normal(0, 1).expand([K]))
    sigma = numpyro.sample("sigma", numpyro.distributions.Exponential(1))
    y_hat = alpha + jnp.dot(X, beta)
    numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)

#### DATA GEN ####
rng_key1, rng_key2, rng_key3 = jax.random.split(jax.random.PRNGKey(42), num=3)
sigma = 1.0
X = numpyro.distributions.Normal(0, 1).sample(rng_key1, sample_shape=(N, K))
beta = numpyro.distributions.Normal(0, 1).sample(rng_key2, sample_shape=(K,))
mu = jnp.dot(X, beta)
y = numpyro.distributions.Normal(mu, sigma).sample(rng_key3, sample_shape=())

#### FIT ####
guide = AutoNormal(model)
svi = SVI(model, guide, optim=adam, loss=Trace_ELBO())
rng_key = jax.random.PRNGKey(99)

svi_result = svi.run(rng_key, N_EPOCHS, X=X, y=y)

### RESULTS ####

# PARAMS ONLY -- wrong, 0.086 corr with true value
pred_params_only = Predictive(model, params=svi_result.params, num_samples=1000)
full_samples = pred_params_only(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples['obs'], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])

# GUIDE ONLY -- wrong, 0.223 corr with true value
pred_guide_only = Predictive(model, guide=guide, num_samples=1000)
full_samples = pred_guide_only(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples['obs'], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])

# BOTH -- correct
pred_params_plus_guide = Predictive(model, params=svi_result.params, guide=guide, num_samples=1000)
full_samples = pred_params_plus_guide(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples['obs'], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])

# PASS POSTERIOR DIRECTLY -- correct
guide_pred = Predictive(guide, params=svi_result.params, num_samples=1000)
posterior_samples = guide_pred(jax.random.PRNGKey(1), X=X)
pred_posterior = Predictive(model, posterior_samples=posterior_samples, num_samples=1000)
full_samples = pred_posterior(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples['obs'], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])

Note that using Predictive with only params or only guide gives wrong answers, and a different wrong answer in each case. Passing both params and guide works, as does passing the posterior samples directly.

I'm hoping to get some insight into what's going on under the hood and the rationale for this behavior.

afaik this is roughly what’s going on:

  • before you train the guide, it has some initial (random) parameters. if you don't pass in the trained params, the guide is executed with those initial params, so you get samples from an untrained guide rather than from the fitted posterior approximation.
  • conversely, if you pass in params without a guide, then afaik the params are ignored and you get samples from the model prior, i.e. a prior predictive. arguably Predictive should raise an error if params is provided without a guide.
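The two bullets above, together with the two working cases from the question, amount to a small decision table. The sketch below is plain Python and does not import numpyro. `predictive_source` is a hypothetical helper that only encodes this thread's reading of where Predictive gets its latent values. Treating guide plus posterior_samples as an error is an assumption of the sketch.

```python
from typing import Any, Mapping, Optional


def predictive_source(
    guide: Optional[Any] = None,
    params: Optional[Mapping[str, Any]] = None,
    posterior_samples: Optional[Mapping[str, Any]] = None,
) -> str:
    """Describe where latent values come from for an argument combination.

    Hypothetical helper: mirrors the explanation in this thread and does not
    call numpyro itself.
    """
    if guide is not None and posterior_samples is not None:
        # Two competing sources of latent values; the sketch rejects this.
        raise ValueError("pass either guide or posterior_samples, not both")
    if posterior_samples is not None:
        # Latent sites in the model are fixed to the supplied draws.
        return "posterior predictive (from posterior_samples)"
    if guide is not None and params is not None:
        # The guide runs with the trained variational parameters.
        return "posterior predictive (trained guide)"
    if guide is not None:
        # The guide runs with whatever initial parameter values it creates.
        return "untrained guide (initial params)"
    # No guide and no samples: the guide's params are assumed not to match
    # any site in the model, so latents are drawn from the prior.
    return "prior predictive (params ignored)" if params else "prior predictive"
```

Mapped onto the question: "params only" lands in the last branch, "guide only" in the untrained-guide branch, and the two correct runs in the two posterior-predictive branches.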

Thanks for the response. We're working on adopting numpyro at our company and this behavior really confused us. What we settled on was passing in posterior samples and writing a wrapper around Predictive that forces people to pass in the right arguments.
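A wrapper along those lines might look like the sketch below. This is not the poster's actual code. `checked_predictive` is a hypothetical name, and the wrapper only accepts posterior_samples on its own or guide together with params. It deliberately rejects params alongside posterior_samples. That is stricter than necessary if your model itself declares numpyro.param sites.

```python
from typing import Any, Callable, Mapping, Optional


def checked_predictive(
    model: Callable,
    *,
    guide: Optional[Callable] = None,
    params: Optional[Mapping[str, Any]] = None,
    posterior_samples: Optional[Mapping[str, Any]] = None,
    num_samples: Optional[int] = None,
    **predictive_kwargs: Any,
):
    """Build a numpyro Predictive only for posterior-predictive combinations.

    Hypothetical wrapper: accepts posterior_samples alone, or guide together
    with params; every other combination raises ValueError.
    """
    has_samples = posterior_samples is not None
    has_guide_and_params = guide is not None and params is not None
    if has_samples and (guide is not None or params is not None):
        raise ValueError("pass posterior_samples alone, without guide/params")
    if not has_samples and not has_guide_and_params:
        if guide is not None:
            raise ValueError("guide was given without trained params")
        if params is not None:
            raise ValueError("params were given without a guide")
        raise ValueError("pass posterior_samples, or guide together with params")
    # Imported here so the argument checks above can be exercised on their own.
    from numpyro.infer import Predictive

    return Predictive(
        model,
        posterior_samples=posterior_samples,
        guide=guide,
        params=params,
        num_samples=num_samples,
        **predictive_kwargs,
    )
```

With the names from the question, `checked_predictive(model, guide=guide, params=svi_result.params, num_samples=1000)` and `checked_predictive(model, posterior_samples=posterior_samples)` both pass the checks, while `checked_predictive(model, params=svi_result.params, num_samples=1000)` raises instead of silently returning a prior predictive.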