Here’s a code example where I train a model and then pass the params/guide from that model to `Predictive`, but with the wrong model function. `Predictive` runs without any error (but in some cases returns the wrong answer).
What’s going on here?
import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
# Problem size.
N = 10_000  # number of synthetic data points
K = 10  # number of regression coefficients
# Optimisation settings.
LR = 0.05  # Adam step size
N_EPOCHS = 10_000  # passed to svi.run as the number of SVI steps
LOW_RANK_DIM = 3  # NOTE(review): not referenced anywhere in the code shown
adam = numpyro.optim.Adam(step_size=LR)
def model(X, y=None):
    """Bayesian linear regression: ``obs ~ Normal(alpha + X @ beta, sigma)``.

    Latent sites: ``alpha`` (intercept), ``beta`` (``K`` coefficients) and
    ``sigma`` (noise scale). ``obs`` is conditioned on ``y`` when it is
    given and sampled when ``y`` is None (as under ``Predictive``).

    Args:
        X: Design matrix whose last dimension must be ``K`` (presumably
            shape ``(num_points, K)``; TODO confirm for other callers).
        y: Observed targets, or None to sample ``obs``.
    """
    alpha = numpyro.sample("alpha", numpyro.distributions.Normal(0, 1))
    beta = numpyro.sample("beta", numpyro.distributions.Normal(0, 1).expand([K]))
    sigma = numpyro.sample("sigma", numpyro.distributions.Exponential(1))
    y_hat = alpha + jnp.dot(X, beta)
    numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)
#### DATA GEN ####
# Synthetic regression data. There is deliberately no intercept term, so
# the "true" value of the model's ``alpha`` is 0.
rng_key1, rng_key2, rng_key3 = jax.random.split(jax.random.PRNGKey(42), 3)
X = numpyro.distributions.Normal(0, 1).sample(rng_key1, sample_shape=(N, K))
beta = numpyro.distributions.Normal(0, 1).sample(rng_key2, sample_shape=(K,))
sigma = 1.0  # true observation noise scale
mu = jnp.dot(X, beta)  # noiseless targets, shape (N,)
y = numpyro.distributions.Normal(mu, sigma).sample(rng_key3)
#### FIT ####
# Fit an AutoNormal guide (an independent Normal per latent site of
# ``model``) with SVI. The learned variational parameters end up in
# ``svi_result.params``.
rng_key = jax.random.PRNGKey(99)
guide = AutoNormal(model)
svi = SVI(model, guide, adam, Trace_ELBO())
svi_result = svi.run(rng_key, N_EPOCHS, X=X, y=y)
#### WRONG MODELS ####
# Deliberately mismatched variants of ``model``, used to probe what
# ``Predictive`` does when given a guide that was trained on ``model``.
def model_with_too_few_args(X, y=None):
    """``model`` without the intercept: ``obs ~ Normal(X @ beta, sigma)``.

    Samples only ``beta`` and ``sigma``. The ``alpha`` site that a guide
    trained on ``model`` provides has no counterpart here.

    Args:
        X: Design matrix whose last dimension must be ``K``.
        y: Observed targets, or None to sample ``obs``.
    """
    beta = numpyro.sample("beta", numpyro.distributions.Normal(0, 1).expand([K]))
    sigma = numpyro.sample("sigma", numpyro.distributions.Exponential(1))
    y_hat = jnp.dot(X, beta)
    numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)
def model_with_too_many_args(X, y=None):
    """``model`` with the coefficients split into two sites, ``beta1`` and ``beta2``.

    ``obs ~ Normal(alpha + X @ beta1 + X @ beta2, sigma)``. A guide trained
    on ``model`` has no ``beta1``/``beta2`` sites, only ``beta``.

    Args:
        X: Design matrix whose last dimension must be ``K``.
        y: Observed targets, or None to sample ``obs``.
    """
    alpha = numpyro.sample("alpha", numpyro.distributions.Normal(0, 1))
    beta1 = numpyro.sample("beta1", numpyro.distributions.Normal(0, 1).expand([K]))
    beta2 = numpyro.sample("beta2", numpyro.distributions.Normal(0, 1).expand([K]))
    sigma = numpyro.sample("sigma", numpyro.distributions.Exponential(1))
    y_hat = alpha + jnp.dot(X, beta1) + jnp.dot(X, beta2)
    numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)
#### GET PREDICTIVE ####
# ``Predictive(model_fn, guide=..., params=...)`` does not check that the
# guide and ``model_fn`` share the same latent sites. It draws samples from
# the guide and substitutes them into ``model_fn`` by site name. As a result:
#   * guide sites that ``model_fn`` never samples are silently ignored, and
#   * ``model_fn`` sites the guide does not provide are drawn from the prior.
# The helpers below make such mismatches visible before predicting.


def _latent_site_names(fn, *args, **kwargs):
    """Return the names of ``fn``'s unobserved, non-auxiliary sample sites.

    ``fn`` is run once under a fixed seed and traced. Pass the arguments it
    would normally receive (including observed data, so that likelihood
    sites are marked observed and excluded).
    """
    trace = numpyro.handlers.trace(
        numpyro.handlers.seed(fn, rng_seed=0)
    ).get_trace(*args, **kwargs)
    return {
        name
        for name, site in trace.items()
        if site["type"] == "sample"
        and not site.get("is_observed", False)
        # Autoguides may add helper sites flagged ``is_auxiliary``; they
        # are not latent variables of the model.
        and not (site.get("infer") or {}).get("is_auxiliary", False)
    }


def _warn_on_site_mismatch(model_fn, guide_fn, params, *args, **kwargs):
    """Warn if ``guide_fn`` (with ``params``) and ``model_fn`` disagree on latent sites.

    Args:
        model_fn: Model that will be handed to ``Predictive``.
        guide_fn: Guide that will be handed to ``Predictive``.
        params: Variational parameters substituted into ``guide_fn``.
        *args, **kwargs: Arguments used to run both functions once.

    Returns:
        ``(ignored, from_prior)``: sorted guide-only site names (which
        ``Predictive`` drops) and sorted model-only site names (which
        ``Predictive`` samples from the model's prior).
    """
    import warnings

    guide_sites = _latent_site_names(
        numpyro.handlers.substitute(guide_fn, data=params), *args, **kwargs
    )
    model_sites = _latent_site_names(model_fn, *args, **kwargs)
    ignored = sorted(guide_sites - model_sites)
    from_prior = sorted(model_sites - guide_sites)
    if ignored or from_prior:
        warnings.warn(
            f"{getattr(model_fn, '__name__', repr(model_fn))}: guide sites "
            f"{ignored} are not in the model and will be ignored; model sites "
            f"{from_prior} are not in the guide and will be drawn from the prior.",
            stacklevel=2,
        )
    return ignored, from_prior


# The guide provides alpha/beta/sigma, but this model only uses beta/sigma,
# so the guide's alpha is dropped. The correlation is still expected to be
# high, because the data were generated without an intercept (true alpha = 0).
_warn_on_site_mismatch(model_with_too_few_args, guide, svi_result.params, X=X, y=y)
pred_params_plus_guide = Predictive(
    model_with_too_few_args,
    params=svi_result.params,
    guide=guide,
    num_samples=1000,
)
full_samples = pred_params_plus_guide(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples["obs"], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])

# The guide has no beta1/beta2 sites, so every predictive sample draws them
# from their Normal(0, 1) prior. Only alpha and sigma come from the fit, so
# this correlation reflects prior noise rather than the trained model.
_warn_on_site_mismatch(model_with_too_many_args, guide, svi_result.params, X=X, y=y)
pred_params_plus_guide = Predictive(
    model_with_too_many_args,
    params=svi_result.params,
    guide=guide,
    num_samples=1000,
)
full_samples = pred_params_plus_guide(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples["obs"], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])