Here is a toy example (I recommend putting each code block into a separate cell in jupyter notebook to see the plots, etc.)

Create training dataset:

```
from typing import Dict, List
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive
def generate_data(n_samples: int, true_params: Dict[str, float], n_normal=5):
    """Generate one cluster of synthetic binomial-regression data.

    Args:
        n_samples: number of rows to generate.
        true_params: mapping with a "true_intercept" entry plus one
            coefficient per feature; the i-th non-intercept entry (in
            insertion order) multiplies feature ``VAR{i}``.
        n_normal: number of standard-normal features to generate.

    Returns:
        Dict with keys ``VAR0..VAR{n_normal-1}`` (features),
        ``N_VAR{n_normal}`` (binomial total counts in [20, 100]) and
        ``N_VAR{n_normal+1}`` (observed successes).
    """
    random_key = random.PRNGKey(0)
    dataset = {}
    for i in range(n_normal):
        random_key, subkey = random.split(random_key)
        dataset[f"VAR{i}"] = random.normal(subkey, (n_samples,))
    # Build the linear predictor starting from the intercept explicitly.
    # (The original code only worked when "true_intercept" happened to be
    # the FIRST key of true_params; any other ordering raised
    # UnboundLocalError on the += below.)
    linear_combination = true_params["true_intercept"]
    idx = 0
    for key, value in true_params.items():
        if key == "true_intercept":
            continue
        linear_combination = linear_combination + value * dataset[f"VAR{idx}"]
        idx += 1
    # Total counts: integers drawn uniformly from 20..100 inclusive.
    # Use n_normal directly instead of reusing the loop index, which was
    # undefined for n_normal == 0.
    random_key, subkey = random.split(random_key)
    dataset[f"N_VAR{n_normal}"] = random.randint(subkey, shape=(n_samples,), minval=20, maxval=101)
    # Observed successes ~ Binomial(total_count, logits=linear_combination).
    # Sample with the freshly split subkey (the original sampled with the
    # carried key right after splitting it, defeating the split).
    random_key, subkey = random.split(random_key)
    binomial_dist = dist.Binomial(total_count=dataset[f"N_VAR{n_normal}"], logits=linear_combination)
    dataset[f"N_VAR{n_normal + 1}"] = binomial_dist.sample(subkey)
    return dataset
def generate_hierarchical_data(
    n_samples: int, true_params: Dict[str, List[float]], n_clusters: int, n_normal=5
):
    """Generate data for several clusters and stack it into one dataset.

    Each entry of ``true_params`` holds one value per cluster; cluster i
    is generated with the i-th value of every parameter and tagged with
    an integer "ID" column.
    """
    merged_dataset = {}
    for cluster_id in range(n_clusters):
        # Slice out this cluster's scalar parameters.
        params = {name: values[cluster_id] for name, values in true_params.items()}
        cluster_dataset = generate_data(n_samples, params, n_normal)
        # Tag every row with the cluster it came from.
        cluster_dataset["ID"] = np.full((n_samples,), cluster_id)
        if not merged_dataset:
            # First cluster seeds the merged dataset.
            merged_dataset = cluster_dataset
        else:
            # Append this cluster's arrays to the accumulated ones.
            merged_dataset = {
                name: np.concatenate((merged_dataset[name], cluster_dataset[name]))
                for name in merged_dataset
            }
    return merged_dataset
# Per-school true parameters: one value per school (5 schools).
true_params = {
    "true_intercept": [1.0, 1.1, 1.0, 1.0, 1.0],
    "true_M": [-0.5, -0.4, -0.4, -0.45, -0.5],
    "true_D": [0.3, 0.35, 0.3, 0.35, 0.3],
    "true_N": [0.2, 0.3, 0.2, 0.3, 0.2]
}
# generate data for 5 schools (10 samples each, 3 normal features)
dataset = generate_hierarchical_data(10, true_params, n_clusters=5, n_normal=3)
# rename variables - names have to match the names of the model arguments
# NOTE(review): this relies on the dataset's dict insertion order
# (VAR0, VAR1, VAR2, N_VAR3, N_VAR4, ID) lining up with new_keys — fragile
# if the generator ever changes key order.
new_keys = ["M", "D", "N", "N_TOTAL", "N_POINTS", "S"]
old_keys = list(dataset.keys())
for key, n_key in zip(old_keys, new_keys):
    dataset[n_key] = dataset.pop(key)
# prepare dataset with features only, pop obs variable (kept for evaluation)
dataset_features = dataset.copy()
N_POINTS = dataset_features.pop("N_POINTS")
# check IDs of schools
print(dataset_features['S'])
```

Create and train model:

```
def model(M, D, N, S, N_TOTAL, N_POINTS=None):
    """Hierarchical binomial regression: per-school intercept and slopes
    drawn from shared group-level priors; N_POINTS successes out of
    N_TOTAL trials with logits given by the linear predictor.
    N_POINTS=None lets Predictive sample the observation site."""
    # number of distinct school ids present in the data
    n_schools = len(np.unique(S))
    # group priors (hyperpriors shared across schools)
    μ_i = numpyro.sample("μ_i", dist.Normal(0, 1))
    σ_i = numpyro.sample("σ_i", dist.Exponential(1))
    μ_bM = numpyro.sample("μ_bM", dist.Normal(0, 1))
    σ_bM = numpyro.sample("σ_bM", dist.Exponential(1))
    μ_bD = numpyro.sample("μ_bD", dist.Normal(0, 1))
    σ_bD = numpyro.sample("σ_bD", dist.Exponential(1))
    μ_bN = numpyro.sample("μ_bN", dist.Normal(0, 1))
    σ_bN = numpyro.sample("σ_bN", dist.Exponential(1))
    # per-school parameters, one draw per school
    with numpyro.plate("plate_i", n_schools):
        i = numpyro.sample("intercept", dist.Normal(μ_i, σ_i))
        bM = numpyro.sample("bM", dist.Normal(μ_bM, σ_bM))
        bD = numpyro.sample("bD", dist.Normal(μ_bD, σ_bD))
        bN = numpyro.sample("bN", dist.Normal(μ_bN, σ_bN))
    # Per-observation linear predictor, indexed by school id.
    # NOTE(review): JAX clamps out-of-range indices instead of raising, so
    # an unseen school id (e.g. S == 10 with 5 trained schools) silently
    # reuses the LAST school's parameters — confirm this is intended when
    # predicting for a new school.
    linear_combination = i[S] + bM[S] * M + bD[S] * D + bN[S] * N
    with numpyro.plate("data", len(S)):
        numpyro.sample("POINTS_counts", dist.Binomial(total_count=N_TOTAL, logits=linear_combination), obs=N_POINTS)
# Fit with NUTS. Warmup/sample counts are deliberately tiny — this is a
# toy run to show shapes and mechanics, not a converged posterior.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=5, num_samples=4)
random_key, subkey = random.split(random.PRNGKey(1))
# dataset keys (M, D, N, S, N_TOTAL, N_POINTS) match the model arguments
mcmc.run(subkey, **dataset)
mcmc.print_summary()
```

Get posterior samples:

```
posterior_samples = mcmc.get_samples()
# check shape of posterior samples:
# per-school sites have shape (num_samples, n_schools),
# group-level sites have shape (num_samples,)
print(posterior_samples["bM"].shape) # (4, 5)
print(posterior_samples["σ_bN"].shape) # (4,)
```

Get predictions:

```
# Posterior predictive: conditions the model on the posterior draws and
# samples only the observation site (N_POINTS=None since it was popped).
predictive = Predictive(model, posterior_samples=posterior_samples, return_sites=("POINTS_counts",))
predictions = predictive(random.PRNGKey(0), **dataset_features)
# average over the posterior-sample axis to get one prediction per row
mean_predictions = np.array(jnp.mean(predictions['POINTS_counts'], axis=0))
# mean absolute error against the held-back observed counts
mae = np.abs(N_POINTS - mean_predictions)
print("MAE: ", mae.mean())
```

Plot predictions:

```
# Observed vs. predicted counts for the training schools: a good fit
# hugs the diagonal.
plt.scatter(N_POINTS, mean_predictions)
plt.title('True vs predicted points for training schools')
plt.xlabel('True')
plt.ylabel('Predicted')
plt.show()
```

Create test dataset:

```
# Single held-out school with parameters near (but not in) the training set.
true_params = {
    "true_intercept": [1.0],
    "true_M": [-0.55],
    "true_D": [0.3],
    "true_N": [0.4]
}
dataset_test = generate_hierarchical_data(10, true_params, n_clusters=1, n_normal=3)
# rename variables - names have to match the names of the model arguments
new_keys = ["M", "D", "N", "N_TOTAL", "N_POINTS", "S"]
old_keys = list(dataset_test.keys())
for key, n_key in zip(old_keys, new_keys):
    dataset_test[n_key] = dataset_test.pop(key)
# prepare dataset with features only, pop obs variable (kept for evaluation)
dataset_features_test = dataset_test.copy()
N_POINTS_test = dataset_features_test.pop("N_POINTS")
# assign unique ID to new test school
# NOTE(review): id 10 is outside the 0..4 range the model was trained on;
# JAX clamps out-of-range indices, so the model will silently use the last
# training school's parameters for this id — verify this is the intended
# new-school prediction strategy.
dataset_features_test['S'] = (dataset_features_test['S']+1)*10
print(dataset_features_test['S'])
```

Get predictions:

```
# Posterior predictive for the held-out school, same mechanics as for the
# training data (only the observation site is sampled).
predictive = Predictive(model, posterior_samples=posterior_samples, return_sites=("POINTS_counts",))
predictions = predictive(random.PRNGKey(0), **dataset_features_test)
# average over the posterior-sample axis to get one prediction per row
mean_predictions_test = np.array(jnp.mean(predictions['POINTS_counts'], axis=0))
mae = np.abs(N_POINTS_test - mean_predictions_test)
print("MAE: ", mae.mean())
```

Plot predictions:

```
# Observed vs. predicted counts for the held-out school.
plt.scatter(N_POINTS_test, mean_predictions_test)
plt.title('True vs predicted points for test school')
plt.xlabel('True')
plt.ylabel('Predicted')
plt.show()
```

I tried what you advised: even with only 4 posterior samples, my code didn't crash. Also, predictions for the test school are not much worse than those for the training schools.