Why does Predictive silently accept wrong models?

Here’s a code example where I train a model and then pass that model’s params and guide to Predictive, but together with the wrong model function. Predictive runs without complaint, but in some cases it returns the wrong answer.

What’s going on here?

```python
import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

N = 10000
LR = 0.05
N_EPOCHS = 10000  # passed to svi.run as the number of SVI steps
K = 10
LOW_RANK_DIM = 3

adam = numpyro.optim.Adam(step_size=LR)


def model(X, y=None):
    alpha = numpyro.sample("alpha", numpyro.distributions.Normal(0, 1))
    beta = numpyro.sample("beta", numpyro.distributions.Normal(0, 1).expand([K]))
    sigma = numpyro.sample("sigma", numpyro.distributions.Exponential(1))
    y_hat = alpha + jnp.dot(X, beta)
    numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)


#### DATA GEN ####
rng_key1, rng_key2, rng_key3 = jax.random.split(jax.random.PRNGKey(42), num=3)
sigma = 1.0
X = numpyro.distributions.Normal(0, 1).sample(rng_key1, sample_shape=(N, K))
beta = numpyro.distributions.Normal(0, 1).sample(rng_key2, sample_shape=(K,))
mu = jnp.dot(X, beta)
y = numpyro.distributions.Normal(mu, sigma).sample(rng_key3, sample_shape=())

#### FIT ####
guide = AutoNormal(model)
svi = SVI(model, guide, optim=adam, loss=Trace_ELBO())
rng_key = jax.random.PRNGKey(99)

svi_result = svi.run(rng_key, N_EPOCHS, X=X, y=y)

#### WRONG MODELS ####


# Same as `model`, but without the "alpha" site.
def model_with_too_few_args(X, y=None):
    beta = numpyro.sample("beta", numpyro.distributions.Normal(0, 1).expand([K]))
    sigma = numpyro.sample("sigma", numpyro.distributions.Exponential(1))
    y_hat = jnp.dot(X, beta)
    numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)


# Same as `model`, but with "beta" split into two sites, "beta1" and "beta2".
def model_with_too_many_args(X, y=None):
    alpha = numpyro.sample("alpha", numpyro.distributions.Normal(0, 1))
    beta1 = numpyro.sample("beta1", numpyro.distributions.Normal(0, 1).expand([K]))
    beta2 = numpyro.sample("beta2", numpyro.distributions.Normal(0, 1).expand([K]))
    sigma = numpyro.sample("sigma", numpyro.distributions.Exponential(1))
    y_hat = alpha + jnp.dot(X, beta1) + jnp.dot(X, beta2)
    numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)


#### GET PREDICTIVE ####

pred_params_plus_guide = Predictive(
    model_with_too_few_args,
    params=svi_result.params,
    guide=guide,
    num_samples=1000,
)
full_samples = pred_params_plus_guide(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples["obs"], axis=0)
# correlation between the observed y and the mean posterior-predictive "obs"
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])

pred_params_plus_guide = Predictive(
    model_with_too_many_args,
    params=svi_result.params,
    guide=guide,
    num_samples=1000,
)
full_samples = pred_params_plus_guide(jax.random.PRNGKey(99), X=X)
y_hat = jnp.mean(full_samples["obs"], axis=0)
print(jnp.corrcoef(y.squeeze(), y_hat)[0, 1])
```

JAX and NumPyro are functional, so we need to explicitly pass params along with a model/guide definition to get an object that can be evaluated. If you pass a “wrong” parameter dict, that’s on you (e.g. it’ll ignore extraneous params, and any model site without a matching sample is simply drawn from its prior). We could potentially add stronger checks to Predictive, and if you’d like to propose doing so, I suggest you make a concrete suggestion in an issue or open a PR with code and/or docstring changes.

From my perspective, the main intended use cases of Predictive are:

  1. pass no samples or guide (=> prior predictive)
  2. pass posterior samples (=> posterior predictive)
  3. as a pure convenience, instead of #2 you can pass a guide and params; Predictive generates samples from the guide and then does the same as in #2.

Thanks for explaining, that is very helpful. We’re going to write some internal stuff that does checks, and we may come back with a code proposal for some validation. Thanks for being open to it and for discussing. 🙂