Here’s some code I’ve written, but I’m struggling to make sense of how to fix it.
There’s some observed data from two countries, from two cities each. So I’d like to make a hierarchical model with two levels:
- each city can have its own mean and std
- each city within the same country shares hyperpriors
- each country shares global hyperpriors
Here’s a complete reproducible example:
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
import pandas as pd
import scipy.stats as stats
from jax import random
numpyro.set_host_device_count(2)
vietnam_mean = 5
malaysia_mean = 4
hanoi = stats.norm(vietnam_mean, 2).rvs(size=10)
haiphong = stats.norm(vietnam_mean, 2).rvs(size=20)
kuala_lumpur = stats.norm(malaysia_mean, 3).rvs(size=15)
kajang = stats.norm(malaysia_mean, 1).rvs(size=10)
data = pd.DataFrame(columns=["values", "city", "country"])
for city_name, country_name, values in (
("hanoi", "vietnam", hanoi),
("haiphong", "vietnam", haiphong),
("kuala_lumpur", "malasyia", kuala_lumpur),
("kajang", "malasyia", kajang),
):
data = pd.concat(
[
data,
pd.DataFrame(
{"values": values, "city": city_name, "country": country_name}
),
]
)
def model(value, city, country):
n_cities = len(np.unique(city))
n_countries = len(np.unique(country))
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
vietnam_mean = numpyro.sample("vietnam_mean", dist.Normal(0, 1))
vietnam_std = numpyro.sample("vietnam_std", dist.HalfNormal(1))
malaysia_mean = numpyro.sample("malaysia_mean", dist.Normal(0, 1))
malaysia_std = numpyro.sample("malaysia_std", dist.HalfNormal(1))
country_mean = jnp.where(country == 0, malaysia_mean, vietnam_mean)
country_std = jnp.where(country == 0, malaysia_std, vietnam_std)
with numpyro.plate("cities", n_cities):
intercept_city = numpyro.sample(
"intercept_city", dist.Normal(country_mean, country_std)
)
mu = intercept_city[city]
numpyro.sample("obs", dist.Normal(mu, sigma), obs=value)
mcmc = infer.MCMC(infer.NUTS(model), num_chains=2, num_samples=500, num_warmup=500)
mcmc.run(
random.PRNGKey(0),
value=data["values"].to_numpy(),
city=data["city"].astype("category").cat.codes.to_numpy(),
country=data["country"].astype("category").cat.codes.to_numpy(),
)
mcmc.print_summary()
If I run this, I get
ValueError: Incompatible shapes for broadcasting: ((4,), (55,))
because country_mean
is of shape (55,)
.
How should I modify the example such that the 2-level hierarchy can work?