Predicting for Schools Outside the Training Range using NumPyro's Predictive Class

Hello everyone,

I have a hierarchical model written in NumPyro that predicts the academic performance of schools based on various factors. The model is trained on data from 10 schools, with S values ranging from 0 to 9. Here is a simplified version of the model code (just for demonstration; the real model uses more variables in the linear combination):

import numpyro
import numpyro.distributions as dist

def model_binomial_hierarchical(M, S, obs=None):
    n_schools = 10  # number of schools in the training data

    # group-level priors
    μ_bM = numpyro.sample("μ_bM", dist.Normal(0, 1))
    σ_bM = numpyro.sample("σ_bM", dist.Exponential(1))

    # per-school slopes
    with numpyro.plate("plate_i", n_schools):
        bM = numpyro.sample("bM", dist.Normal(μ_bM, σ_bM))

    linear_combination = bM[S] * M

    with numpyro.plate("data", len(S)):
        numpyro.sample("academic_performance", dist.Binomial(total_count=100, logits=linear_combination), obs=obs)

My question is how to use the Predictive class from NumPyro to make predictions for schools with S values outside the training range. For example, if I want to predict for a school with S=30, which was not present in the training data, what is the best approach? Or can I simply use Predictive as shown in the documentation example:

predictive = Predictive(model, posterior_samples=posterior_samples)
y_pred = predictive(rng_key, X)["obs"]  # "obs" is the docs' site name; in my model it would be "academic_performance"

Does the Predictive class use the hierarchical structure of the model when making predictions for schools with S values outside the training range?

Directly reusing the trained model and posterior samples with Predictive seems unlikely to give accurate results for unseen schools, since they were not part of training. Specifically, I would like to understand whether Predictive incorporates hierarchical information from similar schools, or leverages the learned group-level parameters, to produce predictions for unseen schools.

Any insights on how the Predictive class handles hierarchical information or whether there are specific techniques to incorporate hierarchical knowledge into the prediction process would be greatly appreciated.

Thank you in advance for your help!

i think you either have to

  • add dummy schools during inference (e.g. by giving them fake data and masking out their likelihood) or;
  • use posterior samples directly to compute linear_combination for the school of interest (see the sketch below)
    the first case would still use Predictive
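
a minimal sketch of the second option (assuming posterior_samples from mcmc and a hypothetical predictor array M_new for the new school; the first option can be set up with numpyro.handlers.mask):

from jax import random
import numpyro.distributions as dist

# draw a fresh slope for the unseen school from the learned group-level
# distribution, once per posterior sample
mu, sigma = posterior_samples["μ_bM"], posterior_samples["σ_bM"]  # each (num_samples,)
bM_new = mu + sigma * random.normal(random.PRNGKey(0), mu.shape)

# simulate the likelihood directly: shape (num_samples, len(M_new))
logits = bM_new[:, None] * M_new[None, :]
y_new = dist.Binomial(total_count=100, logits=logits).sample(random.PRNGKey(1))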

@martinjankowiak

I apologize if I did not clearly articulate my questions earlier. Essentially, what I am trying to ask is the following:

I fit the model using the following data:

# dummy training data
import numpy as np
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

obs = ...
M = ...
S = np.random.randint(0, 10, size=100)

random_key = random.PRNGKey(0)

kernel = NUTS(model_binomial_hierarchical)
mcmc = MCMC(kernel, num_warmup=3000, num_samples=10000)
random_key, subkey = random.split(random_key)
mcmc.run(subkey, M, S, obs=obs)

posterior_samples = mcmc.get_samples()
print(posterior_samples['bM'].shape)  # (10000, 10)

This means I obtained samples of the bM coefficient for each of the ten training schools.

# predictions on the training data 
posterior_predictive = Predictive(model_binomial_hierarchical, posterior_samples=posterior_samples)
posterior_predictions = posterior_predictive(random.PRNGKey(1), M, S)

# dummy test data
M_test = ...
S_test = np.random.randint(10, 15, size=100)

# predictions on the test data 
posterior_predictive = Predictive(model_binomial_hierarchical, posterior_samples=posterior_samples)
posterior_predictions = posterior_predictive(random.PRNGKey(1), M_test, S_test)

How come the last line of code doesn’t crash? How come I obtained quite reasonable predictions?

I would expect it either

  • to crash, because I’m indexing with S values outside the training range, or
  • to leverage hierarchical information from the group priors (which I’d like you to either confirm or refute).

The question raised by @causalmodeling about the Predictive class and its behavior with out-of-range school indices in S_test is intriguing.

I’ve been contemplating a related hypothesis and would be interested to hear your thoughts on it. Could it be that the Predictive class is designed to fall back on the group-level distribution (dist.Normal(μ_bM, σ_bM)) when it encounters an out-of-bounds index in S_test? This would provide a kind of “fallback” when specific information about a school is not available in the training set, keeping the predictions reasonable by using our overall understanding of schools.

This would align with @causalmodeling’s observation that the code doesn’t crash when faced with unseen schools in S_test and even produces reasonable predictions. It’s as though the Predictive class were saying, “When we don’t have specific information about a school, we’ll use our overall understanding of schools in general.”

Could you please confirm or refute this? Is the Predictive class designed to leverage the group priors when confronted with out-of-range values in S_test?
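
For concreteness, here is the behavior I have in mind, produced deliberately rather than by accident: if the per-school site bM is dropped from posterior_samples, Predictive resamples it from dist.Normal(μ_bM, σ_bM) at each posterior draw of the group-level parameters. A rough sketch, assuming the simplified model from the first post (each fresh draw in plate_i is an i.i.d. “new school”, so any in-range index would do):

# drop the per-school slopes; Predictive then resamples them from the group prior
group_samples = {k: v for k, v in posterior_samples.items() if k != "bM"}
predictive = Predictive(model_binomial_hierarchical, posterior_samples=group_samples)

# index an arbitrary in-range school, e.g. 0: each draw of bM[0] is now a fresh
# sample from Normal(μ_bM, σ_bM), i.e. a plausible brand-new school
S_new = np.zeros_like(S_test)
y_new = predictive(random.PRNGKey(2), M_test, S_new)["academic_performance"]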

Thank you for your assistance.

@causalmodeling does your code crash if posterior_samples only has e.g. 3 samples?

i think the indexing in your model isn’t set up to handle the vectorization that Predictive introduces. it’s hard to say since you didn’t provide a complete runnable script, but you probably need something like

linear_combination = bM[..., S] * M
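
to see why the ellipsis matters, here’s a small shape check (assuming bM picks up a leading sample dimension, e.g. 1000 posterior draws):

import jax.numpy as jnp

bM = jnp.ones((1000, 10))  # (num_posterior_samples, n_schools)
S = jnp.array([0, 3, 9])
print(bM[S].shape)         # (3, 10) - indexes the sample axis, not what we want
print(bM[..., S].shape)    # (1000, 3) - indexes the school axis, as intended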

Here is a toy example (I recommend putting each code block into a separate cell in a Jupyter notebook to see the plots, etc.)

Create training dataset:

from typing import Dict, List

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

def generate_data(n_samples: int, true_params: Dict[str, float], n_normal=5):
    random_key = random.PRNGKey(0)

    # one standard-normal feature per coefficient
    dataset = {}
    for i in range(n_normal):
        random_key, subkey = random.split(random_key)
        dataset[f"VAR{i}"] = random.normal(subkey, (n_samples,))

    # linear predictor: intercept plus one term per remaining coefficient
    linear_combination = true_params["true_intercept"]
    idx = 0
    for key, value in true_params.items():
        if key == "true_intercept":
            continue
        linear_combination += value * dataset[f"VAR{idx}"]
        idx += 1

    # generate total_count with numbers between 20 and 100
    random_key, subkey = random.split(random_key)
    dataset[f"N_VAR{n_normal}"] = random.randint(subkey, shape=(n_samples,), minval=20, maxval=101)

    # observed successes from a binomial likelihood
    random_key, subkey = random.split(random_key)
    binomial_dist = dist.Binomial(total_count=dataset[f"N_VAR{n_normal}"], logits=linear_combination)
    dataset[f"N_VAR{n_normal + 1}"] = binomial_dist.sample(subkey)

    return dataset


def generate_hierarchical_data(
    n_samples: int, true_params: Dict[str, List[float]], n_clusters: int, n_normal=5
):
    merged_dataset = {}
    for i in range(n_clusters):
        cluster_params = {key: value[i] for key, value in true_params.items()}
        cluster_dataset = generate_data(n_samples, cluster_params, n_normal)

        # add cluster id to the dataset
        cluster_dataset["ID"] = np.full((n_samples,), i)

        # add data from the cluster to the merged dataset
        if i == 0:  # if this is the first cluster, just use its data
            merged_dataset = cluster_dataset
        else:  # otherwise, concatenate the data from this cluster to the existing data
            for key in merged_dataset.keys():
                merged_dataset[key] = np.concatenate((merged_dataset[key], cluster_dataset[key]))

    return merged_dataset


true_params = {
    "true_intercept": [1.0, 1.1, 1.0, 1.0, 1.0],
    "true_M": [-0.5, -0.4, -0.4, -0.45, -0.5],
    "true_D": [0.3, 0.35, 0.3, 0.35, 0.3],
    "true_N": [0.2, 0.3, 0.2, 0.3, 0.2]
}

# generate data for 5 schools
dataset = generate_hierarchical_data(10, true_params, n_clusters=5, n_normal=3)

# rename variables - names have to match the names of the model arguments
new_keys = ["M", "D", "N", "N_TOTAL", "N_POINTS", "S"]
old_keys = list(dataset.keys())

for key, n_key in zip(old_keys, new_keys):
    dataset[n_key] = dataset.pop(key)

# prepare dataset with features only, pop obs variable
dataset_features = dataset.copy()
N_POINTS = dataset_features.pop("N_POINTS")

# check IDs of schools
print(dataset_features['S'])

Create and train model:

def model(M, D, N, S, N_TOTAL, N_POINTS=None):
    n_schools = len(np.unique(S))
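    # note: n_schools is recomputed from whatever S is passed in, so at
    # prediction time it can differ from the plate size used during training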

    # group priors
    μ_i = numpyro.sample("μ_i", dist.Normal(0, 1))
    σ_i = numpyro.sample("σ_i", dist.Exponential(1))

    μ_bM = numpyro.sample("μ_bM", dist.Normal(0, 1))
    σ_bM = numpyro.sample("σ_bM", dist.Exponential(1))

    μ_bD = numpyro.sample("μ_bD", dist.Normal(0, 1))
    σ_bD = numpyro.sample("σ_bD", dist.Exponential(1))

    μ_bN = numpyro.sample("μ_bN", dist.Normal(0, 1))
    σ_bN = numpyro.sample("σ_bN", dist.Exponential(1))

    with numpyro.plate("plate_i", n_schools):
        i = numpyro.sample("intercept", dist.Normal(μ_i, σ_i))
        bM = numpyro.sample("bM", dist.Normal(μ_bM, σ_bM))
        bD = numpyro.sample("bD", dist.Normal(μ_bD, σ_bD))
        bN = numpyro.sample("bN", dist.Normal(μ_bN, σ_bN))

    linear_combination = i[S] + bM[S] * M + bD[S] * D + bN[S] * N

    with numpyro.plate("data", len(S)):
        numpyro.sample("POINTS_counts", dist.Binomial(total_count=N_TOTAL, logits=linear_combination), obs=N_POINTS)


kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=5, num_samples=4)  # deliberately tiny, per the 3-4 sample question above
random_key, subkey = random.split(random.PRNGKey(1))
mcmc.run(subkey, **dataset)
mcmc.print_summary()

Get posterior samples:

posterior_samples = mcmc.get_samples()

# check shape of posterior samples
print(posterior_samples["bM"].shape) # (4, 5)
print(posterior_samples["σ_bN"].shape) # (4,)

Get predictions:

predictive = Predictive(model, posterior_samples=posterior_samples, return_sites=("POINTS_counts",))
predictions = predictive(random.PRNGKey(0), **dataset_features)
mean_predictions = np.array(jnp.mean(predictions['POINTS_counts'], axis=0))
mae = np.abs(N_POINTS - mean_predictions)
print("MAE: ", mae.mean())

Plot predictions:

plt.scatter(N_POINTS, mean_predictions)
plt.xlabel('True')
plt.ylabel('Predicted')
plt.title('True vs predicted points for training schools')
plt.show()

Create test dataset:

true_params = {
    "true_intercept": [1.0],
    "true_M": [-0.55],
    "true_D": [0.3],
    "true_N": [0.4]
}

dataset_test = generate_hierarchical_data(10, true_params, n_clusters=1, n_normal=3)

# rename variables - names have to match the names of the model arguments
new_keys = ["M", "D", "N", "N_TOTAL", "N_POINTS", "S"]
old_keys = list(dataset_test.keys())

for key, n_key in zip(old_keys, new_keys):
    dataset_test[n_key] = dataset_test.pop(key)

# prepare dataset with features only, pop obs variable
dataset_features_test = dataset_test.copy()
N_POINTS_test = dataset_features_test.pop("N_POINTS")

# assign an out-of-range ID to the new test school (S was all 0, becomes all 10;
# the training schools were 0-4)
dataset_features_test['S'] = (dataset_features_test['S'] + 1) * 10
print(dataset_features_test['S'])

Get predictions:

predictive = Predictive(model, posterior_samples=posterior_samples, return_sites=("POINTS_counts",))
predictions = predictive(random.PRNGKey(0), **dataset_features_test)
mean_predictions_test = np.array(jnp.mean(predictions['POINTS_counts'], axis=0))
mae = np.abs(N_POINTS_test - mean_predictions_test)
print("MAE: ", mae.mean())

Plot predictions:

plt.scatter(N_POINTS_test, mean_predictions_test)
plt.xlabel('True')
plt.ylabel('Predicted')
plt.title('True vs predicted points for test school')
plt.show()

I tried what you advised: with only 4 posterior samples the code still didn’t crash. Also, predictions for the test school are not much worse than those for the training schools.

you’re running into this jax sharp bit: 🔪 JAX – The Sharp Bits 🔪 (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html), specifically out-of-bounds indexing

i guess it’ll give you predictions using the latent variables of one of the training schools: out-of-bounds gather indices are clamped in jax, so S=10 ends up reusing the last school’s (index 4) coefficients
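
a quick illustration of the clamping:

import jax.numpy as jnp

a = jnp.arange(5)          # think: one coefficient per training school
print(a[10])               # out of bounds: clamped to the last index -> 4
print(a[jnp.array([10])])  # same for array indices -> [4]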

Thank you very much for the explanation. I would never have thought of this.