# Predicting for Schools Outside the Training Range using NumPyro's Predictive Class

Hello everyone,

I have a hierarchical model written in NumPyro that predicts the academic performance of schools based on various factors. The model is trained on data from 10 schools with `S` values ranging from 0 to 9. Here is a simplified version of the model code (for demonstration only; the real model uses more variables in the linear combination):

```python
import numpyro
import numpyro.distributions as dist

def model_binomial_hierarchical(M, S, obs=None):
    n_schools = 10  # Number of schools in the training data

    # Group priors
    μ_bM = numpyro.sample("μ_bM", dist.Normal(0, 1))
    σ_bM = numpyro.sample("σ_bM", dist.Exponential(1))

    with numpyro.plate("plate_i", n_schools):
        bM = numpyro.sample("bM", dist.Normal(μ_bM, σ_bM))

    linear_combination = bM[S] * M

    with numpyro.plate("data", len(S)):
        ...  # binomial likelihood on linear_combination with obs=obs (omitted in the original snippet)
```

My question is how to use NumPyro’s `Predictive` class to make predictions for schools with `S` values outside the training range. For example, if I want to predict for a school with `S=30`, which was not present in the training data, what is the best approach? Or can I simply use `Predictive` as shown in the documentation example?

```python
predictive = Predictive(model, posterior_samples=posterior_samples)
y_pred = predictive(rng_key, X)["obs"]
```

Does the `Predictive` class in NumPyro utilize hierarchical information when making predictions for schools with `S` values outside the training range?

Since unseen schools were not used for training, using the trained model and posterior samples directly with `Predictive` may not give accurate results for them. I’m curious whether the `Predictive` class takes advantage of the model’s hierarchical structure to improve predictions for schools with `S` values outside the training range. Specifically, does it incorporate information from similar schools, or leverage the learned coefficients from the training data, when predicting for unseen schools?

Any insights on how the `Predictive` class handles hierarchical information or whether there are specific techniques to incorporate hierarchical knowledge into the prediction process would be greatly appreciated.

I think you either have to

• add dummy schools during inference (e.g. by giving them fake data and masking out their likelihood), or
• use posterior samples directly to compute `linear_combination` for the school of interest.

The first option would still use `Predictive`.
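The second option can be sketched in a few lines of NumPy. This is a minimal sketch, not a NumPyro API: it assumes the simplified model from the question (linear predictor `bM * M`, no intercept), a binomial likelihood with a known `total_count` (the original snippet omits the likelihood), and posterior samples stored under the site names `μ_bM` and `σ_bM`. The helper `predict_new_school` and the `fake_posterior` dictionary are hypothetical stand-ins; in practice the samples would come from `mcmc.get_samples()`.

```python
import numpy as np


def predict_new_school(posterior_samples, M_new, total_count, rng):
    """Posterior predictive draws for a school with no training data (hypothetical helper).

    Assumes logits = bM * M and a Binomial likelihood with a known total_count;
    both are assumptions standing in for the elided parts of the original model.
    """
    mu = np.asarray(posterior_samples["μ_bM"])      # (num_samples,)
    sigma = np.asarray(posterior_samples["σ_bM"])   # (num_samples,)

    # One fresh school-level coefficient per posterior draw, from the group distribution
    bM_new = rng.normal(mu, sigma)                  # (num_samples,)

    # Broadcast to (num_samples, n_obs) logits, then draw Binomial counts
    logits = bM_new[:, None] * np.asarray(M_new)[None, :]
    probs = 1.0 / (1.0 + np.exp(-logits))
    return rng.binomial(total_count, probs)


# Usage with made-up posterior draws (stand-ins for mcmc.get_samples())
rng = np.random.default_rng(0)
fake_posterior = {
    "μ_bM": rng.normal(-0.45, 0.05, size=1000),
    "σ_bM": np.abs(rng.normal(0.1, 0.02, size=1000)),
}
M_new = rng.normal(size=20)
draws = predict_new_school(fake_posterior, M_new, total_count=50, rng=rng)
print(draws.shape)  # (1000, 20)
```

Because a new coefficient is drawn for every posterior draw, the predictions carry both the uncertainty about the group mean and the between-school spread `σ_bM`, so they are typically wider than predictions for a school seen during training.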

@martinjankowiak

I apologize if I did not clearly articulate my questions earlier. Essentially, what I am trying to ask is the following:

I fit the model using the following data:

```python
import numpy as np
from jax import random
from numpyro.infer import MCMC, NUTS

# dummy training data
obs = ...
M = ...
S = np.random.randint(0, 10, size=100)

random_key = random.PRNGKey(0)  # seed chosen for illustration
kernel = NUTS(model_binomial_hierarchical)
mcmc = MCMC(kernel, num_warmup=3000, num_samples=10000)
random_key, subkey = random.split(random_key)
mcmc.run(subkey, M, S, obs=obs)

posterior_samples = mcmc.get_samples()
print(posterior_samples['bM'].shape)  # (10000, 10)
```

This means that I obtained samples of the `bM` coefficient for each of the ten training schools.

```python
from numpyro.infer import Predictive

# predictions on the training data
posterior_predictive = Predictive(model_binomial_hierarchical, posterior_samples=posterior_samples)
posterior_predictions = posterior_predictive(random.PRNGKey(1), M, S)

# dummy test data: school IDs 10-14 were never seen during training
M_test = ...
S_test = np.random.randint(10, 15, size=100)

# predictions on the test data
posterior_predictive = Predictive(model_binomial_hierarchical, posterior_samples=posterior_samples)
posterior_predictions = posterior_predictive(random.PRNGKey(1), M_test, S_test)
```

Why doesn’t the last line of code crash, and why do I get quite reasonable predictions?

I would expect it either

• to crash, because I’m indexing `bM` with `S` values outside the training range, or
• to leverage hierarchical information from the group priors (which I’d like you to either confirm or refute).

The question raised by @causalmodeling about the `Predictive` class and its behavior with out-of-range school indices in `S_test` is intriguing.

I’ve been contemplating a related hypothesis and would like to hear your thoughts on it. Could the `Predictive` class be designed to use the group-level distribution (`dist.Normal(μ_bM, σ_bM)`) when it encounters an out-of-bounds index in `S_test`? That would provide a kind of “fallback” when no school-specific information is available from the training set, keeping the predictions reasonable by relying on our overall understanding of schools.

This seems consistent with @causalmodeling’s observation that the code doesn’t crash on unseen schools in `S_test` and even produces reasonable predictions. It’s as though the `Predictive` class is saying, “When we don’t have specific information about a school, we’ll use our overall understanding of schools in general.”

Could you please confirm or refute this? Is the `Predictive` class designed to leverage the group priors when confronted with out-of-range values in `S_test`?
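For reference, the group-level “fallback” described here has a precise form in a hierarchical model: a school with no training data gets the posterior predictive distribution of a *new* coefficient, obtained by averaging the group distribution `dist.Normal(μ_bM, σ_bM)` over the posterior of the hyperparameters. Here $\mathcal{D}$ denotes the training data and $\mathcal{N}$ is parameterized by mean and standard deviation, as in NumPyro:

$$
p\big(b_M^{\text{new}} \mid \mathcal{D}\big) = \int \mathcal{N}\big(b_M^{\text{new}} \mid \mu_{bM}, \sigma_{bM}\big)\, p\big(\mu_{bM}, \sigma_{bM} \mid \mathcal{D}\big)\, d\mu_{bM}\, d\sigma_{bM}
$$

With posterior draws $(\mu_{bM}^{(k)}, \sigma_{bM}^{(k)})$ for $k = 1, \dots, K$, this integral is approximated by drawing one $b_M^{\text{new},(k)} \sim \mathcal{N}(\mu_{bM}^{(k)}, \sigma_{bM}^{(k)})$ per draw.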

@causalmodeling, does your code crash if `posterior_samples` has only, say, 3 samples?

I think the indexing in your model isn’t set up to handle the vectorization that `Predictive` expects. It’s hard to say since you didn’t provide a complete runnable script, but you probably need something like
`linear_combination = bM[..., S] * M`
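The difference between `bM[S]` and `bM[..., S]` shows up once `bM` carries a leading sample axis, which may happen depending on how `Predictive` vectorizes the model. The NumPy sketch below uses a made-up array with 3 draws and 5 schools (a stand-in for `posterior_samples["bM"]`, not real output) to show why 3 samples is a useful test: `bM[S]` indexes the sample axis and fails for school index 4, while `bM[..., S]` selects schools along the last axis with or without a sample axis. NumPy raises `IndexError` here; JAX indexing does not raise on out-of-bounds indices.

```python
import numpy as np

num_samples, n_schools = 3, 5
# Stand-in for posterior_samples["bM"]: one row per posterior draw, one column per school
bM = np.arange(num_samples * n_schools, dtype=float).reshape(num_samples, n_schools)
S = np.array([0, 1, 4])  # hypothetical school index of each observation

# Ellipsis indexing picks schools from the last axis, batched or not
per_obs_batched = bM[..., S]     # (3, 3): draws x observations
per_obs_single = bM[0][..., S]   # (3,): a single draw

# Plain bM[S] indexes the FIRST axis, i.e. the sample axis when bM is batched
try:
    bM[S]
    plain_index_failed = False
except IndexError:
    plain_index_failed = True  # school index 4 is out of range for only 3 draws

print(per_obs_batched.shape, per_obs_single.shape, plain_index_failed)
# (3, 3) (3,) True
```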

Here is a toy example (I recommend putting each code block in a separate Jupyter notebook cell to see the plots, etc.):

Create training dataset:

```python
from typing import Dict, List

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

def generate_data(n_samples: int, true_params: Dict[str, float], n_normal=5):
    # fixed seed: every call (i.e. every cluster) gets the same feature values
    random_key = random.PRNGKey(0)

    # standard-normal features VAR0 ... VAR{n_normal - 1}
    dataset = {}
    for i in range(0, n_normal):
        random_key, subkey = random.split(random_key)
        dataset[f"VAR{i}"] = random.normal(subkey, (n_samples,))

    # logits = intercept + sum of coefficient * feature, in the order of true_params
    idx = 0
    for key, value in true_params.items():
        if key == "true_intercept":
            linear_combination = true_params["true_intercept"]
            continue
        var_key = "VAR" + str(idx)
        linear_combination += value * dataset[var_key]
        idx += 1

    # generate total_count with numbers between 20-100
    i += 1  # reuses the feature-loop index, so i == n_normal here
    random_key, subkey = random.split(random_key)
    dataset[f"N_VAR{i}"] = random.randint(subkey, shape=(n_samples,), minval=20, maxval=101)

    # obs with binomial distribution
    random_key, subkey = random.split(random_key)
    binomial_dist = dist.Binomial(total_count=dataset[f"N_VAR{i}"], logits=linear_combination)
    dataset[f"N_VAR{i+1}"] = binomial_dist.sample(random_key)

    return dataset

def generate_hierarchical_data(
    n_samples: int, true_params: Dict[str, List[float]], n_clusters: int, n_normal=5
):
    merged_dataset = {}
    for i in range(n_clusters):
        cluster_params = {key: value[i] for key, value in true_params.items()}
        cluster_dataset = generate_data(n_samples, cluster_params, n_normal)

        # add cluster id to the dataset
        cluster_dataset["ID"] = np.full((n_samples,), i)

        # add data from the cluster to the merged dataset
        if i == 0:  # if this is the first cluster, just use its data
            merged_dataset = cluster_dataset
        else:  # otherwise, concatenate the data from this cluster to the existing data
            for key in merged_dataset.keys():
                merged_dataset[key] = np.concatenate((merged_dataset[key], cluster_dataset[key]))

    return merged_dataset

true_params = {
    "true_intercept": [1.0, 1.1, 1.0, 1.0, 1.0],
    "true_M": [-0.5, -0.4, -0.4, -0.45, -0.5],
    "true_D": [0.3, 0.35, 0.3, 0.35, 0.3],
    "true_N": [0.2, 0.3, 0.2, 0.3, 0.2]
}

# generate data for 5 schools
dataset = generate_hierarchical_data(10, true_params, n_clusters=5, n_normal=3)

# rename variables - names have to match the names of the model arguments
new_keys = ["M", "D", "N", "N_TOTAL", "N_POINTS", "S"]
old_keys = list(dataset.keys())

for key, n_key in zip(old_keys, new_keys):
    dataset[n_key] = dataset.pop(key)

# prepare dataset with features only, pop obs variable
dataset_features = dataset.copy()
N_POINTS = dataset_features.pop("N_POINTS")

# check IDs of schools
print(dataset_features['S'])
```

Create and train model:

```python
def model(M, D, N, S, N_TOTAL, N_POINTS=None):
    # inferred from the S passed in, so it differs between training and test calls
    n_schools = len(np.unique(S))

    # group priors
    μ_i = numpyro.sample("μ_i", dist.Normal(0, 1))
    σ_i = numpyro.sample("σ_i", dist.Exponential(1))

    μ_bM = numpyro.sample("μ_bM", dist.Normal(0, 1))
    σ_bM = numpyro.sample("σ_bM", dist.Exponential(1))

    μ_bD = numpyro.sample("μ_bD", dist.Normal(0, 1))
    σ_bD = numpyro.sample("σ_bD", dist.Exponential(1))

    μ_bN = numpyro.sample("μ_bN", dist.Normal(0, 1))
    σ_bN = numpyro.sample("σ_bN", dist.Exponential(1))

    with numpyro.plate("plate_i", n_schools):
        i = numpyro.sample("intercept", dist.Normal(μ_i, σ_i))
        bM = numpyro.sample("bM", dist.Normal(μ_bM, σ_bM))
        bD = numpyro.sample("bD", dist.Normal(μ_bD, σ_bD))
        bN = numpyro.sample("bN", dist.Normal(μ_bN, σ_bN))

    linear_combination = i[S] + bM[S] * M + bD[S] * D + bN[S] * N

    with numpyro.plate("data", len(S)):
        numpyro.sample("POINTS_counts", dist.Binomial(total_count=N_TOTAL, logits=linear_combination), obs=N_POINTS)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=5, num_samples=4)
random_key, subkey = random.split(random.PRNGKey(1))
mcmc.run(subkey, **dataset)
mcmc.print_summary()
```

Get posterior samples:

```python
posterior_samples = mcmc.get_samples()

# check shape of posterior samples
print(posterior_samples["bM"].shape)  # (4, 5)
print(posterior_samples["σ_bN"].shape)  # (4,)
```

Get predictions:

```python
predictive = Predictive(model, posterior_samples=posterior_samples, return_sites=("POINTS_counts",))
predictions = predictive(random.PRNGKey(0), **dataset_features)
mean_predictions = np.array(jnp.mean(predictions['POINTS_counts'], axis=0))
mae = np.abs(N_POINTS - mean_predictions)
print("MAE: ", mae.mean())
```

Plot predictions:

```python
plt.scatter(N_POINTS, mean_predictions)
plt.xlabel('True')
plt.ylabel('Predicted')
plt.title('True vs predicted points for training schools')
plt.show()
```

Create test dataset:

```python
true_params = {
    "true_intercept": [1.0],
    "true_M": [-0.55],
    "true_D": [0.3],
    "true_N": [0.4]
}

dataset_test = generate_hierarchical_data(10, true_params, n_clusters=1, n_normal=3)

# rename variables - names have to match the names of the model arguments
new_keys = ["M", "D", "N", "N_TOTAL", "N_POINTS", "S"]
old_keys = list(dataset_test.keys())

for key, n_key in zip(old_keys, new_keys):
    dataset_test[n_key] = dataset_test.pop(key)

# prepare dataset with features only, pop obs variable
dataset_features_test = dataset_test.copy()
N_POINTS_test = dataset_features_test.pop("N_POINTS")

# assign unique ID to new test school
dataset_features_test['S'] = (dataset_features_test['S']+1)*10
print(dataset_features_test['S'])
```

Get predictions:

```python
predictive = Predictive(model, posterior_samples=posterior_samples, return_sites=("POINTS_counts",))
predictions = predictive(random.PRNGKey(0), **dataset_features_test)
mean_predictions_test = np.array(jnp.mean(predictions['POINTS_counts'], axis=0))
mae = np.abs(N_POINTS_test - mean_predictions_test)
print("MAE: ", mae.mean())
```

Plot predictions:

```python
plt.scatter(N_POINTS_test, mean_predictions_test)
plt.xlabel('True')
plt.ylabel('Predicted')
plt.title('True vs predicted points for test school')
plt.show()
```

I tried what you suggested: with only 4 posterior samples, my code didn’t crash. Also, the predictions for the test school are not much worse than those for the training schools.

You’re running into this JAX sharp bit, out-of-bounds indexing, described in “🔪 JAX - The Sharp Bits 🔪” in the JAX documentation.

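I guess it’ll give you predictions using the latent variables of one of the training schools rather than the group prior: JAX clamps an out-of-bounds index to the nearest valid one, so an ID like `10` with 5 training schools picks the last school (index `4`).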
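A minimal sketch of this sharp bit, using the toy example’s test ID (`S = 10` with 5 training schools). The coefficient values are illustrative stand-ins for one posterior draw of `bM`, not actual MCMC output. NumPy raises on the out-of-range index, plain JAX indexing silently clamps it to the last valid position, and `.at[...].get(mode="fill", fill_value=...)` is one way to make such lookups show up as NaN instead.

```python
import jax.numpy as jnp
import numpy as np

# Illustrative stand-in for one posterior draw of bM for the 5 training schools
bM = jnp.array([-0.5, -0.4, -0.4, -0.45, -0.5])
S_test = np.full(10, 10)  # test school ID from the toy example: (0 + 1) * 10

# NumPy refuses the out-of-range index
try:
    np.asarray(bM)[S_test]
    numpy_raised = False
except IndexError:
    numpy_raised = True

# JAX clamps out-of-bounds gather indices, so every entry equals bM[4] (the last school)
clamped = bM[S_test]

# Explicit fill semantics turn out-of-range lookups into NaN instead
filled = bM.at[S_test].get(mode="fill", fill_value=jnp.nan)

print(numpy_raised, bool(jnp.all(clamped == bM[4])), bool(jnp.all(jnp.isnan(filled))))
```

If this is what happens inside the model, the test predictions reuse one training school’s coefficients instead of sampling from the group distribution, which would be consistent with them looking reasonable when the training schools are similar to the test school.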

Thank you very much for the explanation. I would never have thought of this.