Hierarchical model with two levels

Here’s some code I’ve written, but I’m struggling to work out how to fix it.

I have observed data from two countries, with two cities in each, so I’d like to build a hierarchical model with two levels:

  • each city can have its own mean and std
  • each city within the same country shares hyperpriors
  • all countries share global hyperpriors
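Written out in my own notation (K countries, J cities, N rows; k[j] is the country of city j and j[i] is the city of row i), the hierarchy in these bullets looks roughly like the following. This is a sketch: it borrows the prior scales used in the later posts and a single shared observation noise σ as in the model code, whereas the first bullet's per-city std would replace σ with a separate σ_j for each city.

```latex
\begin{aligned}
\mu_0 &\sim \mathcal{N}(0, 5), \quad \tau_0 \sim \operatorname{HalfNormal}(5) && \text{global} \\
\mu_k &\sim \mathcal{N}(\mu_0, \tau_0), \quad \tau_k \sim \operatorname{HalfNormal}(5) && \text{country } k = 1, \dots, K \\
\alpha_j &\sim \mathcal{N}\big(\mu_{k[j]}, \tau_{k[j]}\big) && \text{city } j = 1, \dots, J \\
y_i &\sim \mathcal{N}\big(\alpha_{j[i]}, \sigma\big), \quad \sigma \sim \operatorname{HalfNormal}(1) && \text{row } i = 1, \dots, N
\end{aligned}
```

In this example K = 2, J = 4 and N = 10 + 20 + 15 + 10 = 55.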

Here’s a complete reproducible example:

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
import pandas as pd
import scipy.stats as stats
from jax import random

numpyro.set_host_device_count(2)

vietnam_mean = 5
malaysia_mean = 4
hanoi = stats.norm(vietnam_mean, 2).rvs(size=10)
haiphong = stats.norm(vietnam_mean, 2).rvs(size=20)
kuala_lumpur = stats.norm(malaysia_mean, 3).rvs(size=15)
kajang = stats.norm(malaysia_mean, 1).rvs(size=10)

frames = []
for city_name, country_name, values in (
    ("hanoi", "vietnam", hanoi),
    ("haiphong", "vietnam", haiphong),
    ("kuala_lumpur", "malaysia", kuala_lumpur),
    ("kajang", "malaysia", kajang),
):
    frames.append(
        pd.DataFrame({"values": values, "city": city_name, "country": country_name})
    )
# Concatenate once, rather than growing an empty object-dtype frame inside the loop
data = pd.concat(frames, ignore_index=True)


def model(value, city, country):
    n_cities = len(np.unique(city))
    n_countries = len(np.unique(country))

    sigma = numpyro.sample("sigma", dist.HalfNormal(1))

    vietnam_mean = numpyro.sample("vietnam_mean", dist.Normal(0, 1))
    vietnam_std = numpyro.sample("vietnam_std", dist.HalfNormal(1))
    malaysia_mean = numpyro.sample("malaysia_mean", dist.Normal(0, 1))
    malaysia_std = numpyro.sample("malaysia_std", dist.HalfNormal(1))

    country_mean = jnp.where(country == 0, malaysia_mean, vietnam_mean)
    country_std = jnp.where(country == 0, malaysia_std, vietnam_std)
    with numpyro.plate("cities", n_cities):
        intercept_city = numpyro.sample(
            "intercept_city", dist.Normal(country_mean, country_std)
        )

    mu = intercept_city[city]

    numpyro.sample("obs", dist.Normal(mu, sigma), obs=value)


mcmc = infer.MCMC(infer.NUTS(model), num_chains=2, num_samples=500, num_warmup=500)
mcmc.run(
    random.PRNGKey(0),
    value=data["values"].to_numpy(),
    city=data["city"].astype("category").cat.codes.to_numpy(),
    country=data["country"].astype("category").cat.codes.to_numpy(),
)
mcmc.print_summary()

If I run this, I get

ValueError: Incompatible shapes for broadcasting: ((4,), (55,))

because country_mean has shape (55,) (one entry per row of data, 10 + 20 + 15 + 10), while the "cities" plate has size 4.
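To see where the two shapes come from, here is a NumPy-only sketch (no NumPyro involved) that rebuilds the per-row codes the example passes in. It assumes the alphabetical category codes pandas produces for this data, and the two country means are stand-in numbers, not sampled values.

```python
import numpy as np

# Per-row integer codes for the 10 + 20 + 15 + 10 = 55 rows, in the order the
# DataFrame is built (hanoi, haiphong, kuala_lumpur, kajang). Category codes are
# alphabetical: haiphong=0, hanoi=1, kajang=2, kuala_lumpur=3; malaysia=0, vietnam=1.
city = np.repeat([1, 0, 3, 2], [10, 20, 15, 10])
country = np.repeat([1, 1, 0, 0], [10, 20, 15, 10])

n_cities = len(np.unique(city))

# Stand-in numbers for the sampled country means (not real posterior values).
malaysia_mean, vietnam_mean = 4.0, 5.0
country_mean = np.where(country == 0, malaysia_mean, vietnam_mean)

print(country_mean.shape)  # (55,): one entry per row, not per city
print(n_cities)            # 4: the size of the "cities" plate

# A (4,) batch and a (55,) batch cannot be broadcast together.
try:
    np.broadcast_shapes((n_cities,), country_mean.shape)
except ValueError as err:
    print("cannot broadcast:", err)
```

The jnp.where in the model is elementwise over the per-row country codes, which is why its result lives on the "rows" axis rather than the "cities" axis.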

How should I modify the example such that the 2-level hierarchy can work?

I guess you probably want something like

with numpyro.plate("cities", len(city)):  # no unique

Either that, or you need to change your country_mean indexing logic so that the result has length n_cities.
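A NumPy-only sketch of the second option: build a per-city lookup of country codes, so that indexing a length-n_countries array of country parameters yields one entry per city, matching the plate. The name city_to_country is hypothetical, the codes assume pandas' alphabetical category ordering, and the means are stand-in numbers.

```python
import numpy as np

# Same per-row codes as in the example (55 rows; alphabetical category codes).
city = np.repeat([1, 0, 3, 2], [10, 20, 15, 10])
country = np.repeat([1, 1, 0, 0], [10, 20, 15, 10])
n_cities = len(np.unique(city))

# Hypothetical helper: city_to_country[c] is the country code of city c.
# Each city belongs to exactly one country, so all rows of city c write the same value.
city_to_country = np.empty(n_cities, dtype=int)
city_to_country[city] = country

# Stand-in for a length-n_countries sample of country means.
country_means = np.array([4.0, 5.0])  # index 0 = malaysia, 1 = vietnam
mean_per_city = country_means[city_to_country]

print(city_to_country)  # [1 1 0 0]
print(mean_per_city)    # [5. 5. 4. 4.]
```

Inside the model, the same indexing would go into the "cities" plate, with country_mean and country_std sampled as length-n_countries arrays, e.g. dist.Normal(country_mean[city_to_country], country_std[city_to_country]).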

Thanks @martinjankowiak!

Indeed, if I do that, it compiles, but then I get a different value of intercept_city for each row of data.

What I was hoping for is that each city would have its own intercept, which would be shrunk towards its country’s intercept, which in turn would be shrunk towards some global intercept.

I’m not sure whether what I’m asking is possible or makes sense, sorry.

Essentially, I could do:

    intercept_city = numpyro.sample('intercept_city', dist.Normal(0, 1).expand([n_countries, n_cities]))

and then each city would be treated separately. But I’d like the cities within each country to share a common hyperprior.

Something else I’ve tried is

    with numpyro.plate('country', n_countries):
        intercept_country = numpyro.sample('intercept_country', dist.Normal(0, 1))
        with numpyro.plate('city', n_cities):
            intercept_city = numpyro.sample('intercept_city', dist.Normal(0, 1))

    # the inner 'city' plate takes dim -2, so intercept_city has shape (n_cities, n_countries)
    mu = intercept_country[country] + intercept_city[city, country]

but then there’s no hierarchy: both intercepts just get fixed Normal(0, 1) priors.

Trying again, this seems to work:

    with numpyro.plate('country', n_countries):
        country_mean = numpyro.sample('country_mean', dist.Normal(0, 5))
        country_sd = numpyro.sample('country_sd', dist.HalfNormal(5))
        with numpyro.plate('city', n_cities):
            intercept_city = numpyro.sample('intercept_city', dist.Normal(country_mean, country_sd))

    mu = intercept_city[city, country]

I’m still trying to figure out whether it’s doing what I expected, though.

Have you confirmed that this gives you what you expect? I’m stuck on the same issue, with the model not producing what I intend. When I tried this, I got a Cartesian product of city and country: parameters are generated for (city, country) pairs that don’t exist, such as (hanoi, malaysia) and (kuala_lumpur, vietnam).
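A NumPy-only sketch of that observation, assuming the same alphabetical category codes as the example: mark which cells of the (n_cities, n_countries) grid of intercept_city are ever reached by intercept_city[city, country]. Only half of them are; the rest correspond to phantom pairs that the likelihood never touches, so they are informed only by their prior.

```python
import numpy as np

# Per-row codes for the example data (alphabetical category codes, 55 rows).
city = np.repeat([1, 0, 3, 2], [10, 20, 15, 10])
country = np.repeat([1, 1, 0, 0], [10, 20, 15, 10])
n_cities, n_countries = 4, 2

# Mark which cells of an (n_cities, n_countries) parameter grid are indexed
# by the data via intercept_city[city, country].
used = np.zeros((n_cities, n_countries), dtype=bool)
used[city, country] = True

print(used.size)                      # 8 parameters in the grid
print(int(used.sum()))                # 4 of them are touched by observations
print(np.argwhere(~used).tolist())    # [[0, 0], [1, 0], [2, 1], [3, 1]]: phantom pairs
```

With these codes, [0, 0] is (haiphong, malaysia) and [3, 1] is (kuala_lumpur, vietnam).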

I figured it out:

    # Integer country code of each city, in the same (alphabetical) order as the city codes;
    # .first() alone returns country names, which can't be used to index arrays.
    city_to_country_lookup = (
        data.groupby('city')['country'].first().astype('category').cat.codes.to_numpy()
    )
    with numpyro.plate('country', n_countries):
        country_mean = numpyro.sample('country_mean', dist.Normal(0, 5))
        country_sd = numpyro.sample('country_sd', dist.HalfNormal(5))
    with numpyro.plate('city', n_cities):
        intercept_city = numpyro.sample(
            'intercept_city',
            dist.Normal(country_mean[city_to_country_lookup], country_sd[city_to_country_lookup]),
        )

    mu = intercept_city[city]