Hi,
I’m following the turorial on hierarchical_linear_regression and trying to apply a multilevel model to new, unseen data.
The basic idea I want to implement is similar to this out-of-sample prediction example from PyMC.
The challenge is that I’m using an index as an input to the model, and the group-level posterior samples aren’t directly compatible with new data containing unseen indices.
I’m attaching a simple reproducible example below. While the code runs, I believe the last step—posterior predictive sampling with test data—is not working as intended.
I’d appreciate any guidance on how to handle this correctly. Thank you!
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from jax import numpy as jnp
from jax import random
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from sklearn.preprocessing import LabelEncoder
# Load the dataset
df = pd.read_csv(
"https://gist.githubusercontent.com/ucals/"
"2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/"
"43034c39052dcf97d4b894d2ec1bc3f90f3623d9/"
"osic_pulmonary_fibrosis.csv"
)
df["patient_code"] = LabelEncoder().fit_transform(df["Patient"].values)
# Split based on patient_code
df_train = df[df["patient_code"] <= 99].reset_index(drop=True)
df_test = df[df["patient_code"] > 99].reset_index(drop=True)
# Convert to JAX arrays
FVC_obs_train = jnp.asarray(df_train["FVC"].values)
FVC_obs_test = jnp.asarray(df_test["FVC"].values)
Weeks_train = jnp.asarray(df_train["Weeks"].values)
Weeks_test = jnp.asarray(df_test["Weeks"].values)
patient_code_train = jnp.asarray(df_train["patient_code"].values)
patient_code_test = jnp.asarray(df_test["patient_code"].values)
# Verify the split
print("Train set unique patients:", df_train["Patient"].nunique())
print("Test set unique patients:", df_test["Patient"].nunique())
print(
"Overlap in patients:",
set(df_train["Patient"]).intersection(set(df_test["Patient"])),
) # Should be empty
print(
"Overlap in patients:",
set(df_train["patient_code"]).intersection(set(df_test["patient_code"])),
) # Should be empty
# Model
def model(patient_code, Weeks, FVC_obs=None):
n_patients = len(np.unique(patient_code))
μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
with numpyro.plate("plate_i", n_patients):
α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
β = numpyro.sample("β", dist.Normal(0, 1))
FVC_est = α[patient_code] + β * Weeks
with numpyro.plate("data", len(patient_code)):
numpyro.sample("obs", dist.Normal(FVC_est, 1), obs=FVC_obs)
# VI
rng_key = random.key(0)
vi = SVI(
model,
guide=AutoNormal(model=model),
optim=numpyro.optim.Adam(step_size=1e-3),
loss=Trace_ELBO(),
)
rng_key, rng_subkey = random.split(key=rng_key)
result = vi.run(
rng_subkey,
num_steps=1000,
patient_code=patient_code_train,
Weeks=Weeks_train,
FVC_obs=FVC_obs_train,
)
# Posterior sampling
predict1 = Predictive(
vi.guide,
params=result.params,
num_samples=1000,
)
rng_key, rng_subkey = random.split(key=rng_key)
posterior_samples = predict1(rng_subkey)
# Posterior predictive sampling (in-sample)
predict2 = Predictive(model, posterior_samples=posterior_samples)
rng_key, rng_subkey = random.split(key=rng_key)
posterior_predictive_samples = predict2(
rng_key, patient_code=patient_code_train, Weeks=Weeks_train
)
and this is where I get confused.
predict2(rng_key, patient_code=patient_code_test, Weeks=Weeks_test)