Thanks @martinjankowiak!
Indeed, if I do that, it compiles, but then I get a different value of `intercept_city` for each row of data.
What I was hoping for is that each city would have its own intercept, which would be shrunk towards its country’s intercept, which in turn would be shrunk towards a global intercept.
I’m not sure whether what I’m asking is possible or makes sense, sorry.
Essentially, I could do:
```python
intercept_city = numpyro.sample('intercept_city', dist.Normal(0, 1).expand([n_countries, n_cities]))
```
Then each city would be treated independently. But I’d like the cities within each country to share a common hyperprior.
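Something else I’ve tried is:
```python
with numpyro.plate('country', n_countries):
    intercept_country = numpyro.sample('intercept_country', dist.Normal(0, 1))
    with numpyro.plate('city', n_cities):
        intercept_city = numpyro.sample('intercept_city', dist.Normal(0, 1))
# Nested plates without explicit dims are allocated right to left:
# 'country' -> dim -1, 'city' -> dim -2, so intercept_city has shape (n_cities, n_countries).
mu = intercept_country[country] + intercept_city[city, country]
```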
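but then there’s no hierarchy: `intercept_city` is drawn from `Normal(0, 1)` regardless of `intercept_country`, so the city intercepts aren’t shrunk towards their country’s intercept.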