It has been a while since I last did any modeling. Here is my attempt (I haven’t tested the code yet — I just want to give you a template showing common NumPyro practices). As you will see, the code closely mirrors the R code. I just split the random walks into drift terms to make them easier to reparameterize.
# Sites that get a fully non-centered parameterization (LocScaleReparam with
# centered=0): each is a Normal whose scale is itself a sampled hyperparameter,
# which is the geometry non-centering typically helps NUTS with.
# "beta_age_drift" is listed alongside "alpha_age_drift" because the two sites
# are built identically in the model.
reparam_config = {
    site: LocScaleReparam(0)
    for site in [
        "alpha_s1", "beta_s1",
        "alpha_s2", "beta_s2",
        "alpha_s3", "beta_s3",
        "alpha_age_drift", "beta_age_drift",
        "xi", "nu_drift", "gamma_drift",
    ]
}
def _diffuse_first_scale(step_scale, n, first_scale=100.0):
    """Return an ``(n, 1)`` column of Normal scales for a drift-parameterized walk.

    Entry 0 (the walk's starting level) gets the wide ``first_scale``; the
    remaining ``n - 1`` entries (the increments) get ``step_scale``. The trailing
    singleton axis lines the column up with a plate at ``dim=-2``.
    """
    return jnp.where(jnp.arange(n) == 0, first_scale, step_scale)[:, None]


def _walk_from_zero(drift):
    """Cumulative sum of ``drift`` along the last axis with a leading 0 prepended.

    The result is one element longer than ``drift`` on the last axis and its
    first entry is exactly 0. Leading axes pass through unchanged.
    """
    start = jnp.zeros_like(drift[..., :1])
    return jnp.concatenate([start, jnp.cumsum(drift, axis=-1)], axis=-1)


@numpyro.handlers.reparam(config=reparam_config)
def model(
    N_s1, N_s2, N_s3, N_age, N_t, N,
    hier1_id, hier2_id, hier3_id, age_id, year_id,
    population, deaths=None,
):
    """Hierarchical age-space-time BetaBinomial model for death counts.

    The latent logit rate is ``xi + nu + gamma`` on a ``(N_s3, N_age, N_t)``
    grid (axis order: space, age, year):

    * ``xi``: age-by-area intercepts;
    * ``nu``: a per-area random walk over years;
    * ``gamma``: a per-age random walk over years.

    Random walks are written as sampled increments ("drifts") followed by a
    cumulative sum, so that ``reparam_config`` can non-center the increments.

    Args:
        N_s1, N_s2, N_s3: number of areas at spatial levels 1, 2 and 3.
        N_age: number of age groups.
        N_t: number of years.
        N: number of observations.
        hier1_id: length-``N_s2`` index of each level-2 area's level-1 parent.
        hier2_id: length-``N_s3`` index of each level-3 area's level-2 parent.
        hier3_id, age_id, year_id: per-observation indices into
            ``[0, N_s3)``, ``[0, N_age)`` and ``[0, N_t)`` respectively.
        population: BetaBinomial ``total_count`` for each observation.
        deaths: observed death counts, or ``None`` to sample them instead
            (e.g. for predictive simulation).
    """
    space_plate = numpyro.plate("space", N_s3, dim=-3)
    age_plate = numpyro.plate("age_groups", N_age, dim=-2)
    # One drift per year-to-year transition, hence N_t - 1.
    year_plate = numpyro.plate("year", N_t - 1, dim=-1)

    # Global scale hyperpriors.
    sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.Uniform(0, 2))
    sigma_beta_s1 = numpyro.sample("sigma_beta_s1", dist.Uniform(0, 2))
    sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.Uniform(0, 2))
    sigma_beta_s2 = numpyro.sample("sigma_beta_s2", dist.Uniform(0, 2))
    sigma_alpha_s3 = numpyro.sample("sigma_alpha_s3", dist.Uniform(0, 2))
    sigma_beta_s3 = numpyro.sample("sigma_beta_s3", dist.Uniform(0, 2))
    sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.Uniform(0, 2))
    sigma_beta_age = numpyro.sample("sigma_beta_age", dist.Uniform(0, 2))
    sigma_xi = numpyro.sample("sigma_xi", dist.Uniform(0, 2))
    sigma_nu = numpyro.sample("sigma_nu", dist.Uniform(0, 2))
    sigma_gamma = numpyro.sample("sigma_gamma", dist.Uniform(0, 2))
    # Concentration of the BetaBinomial likelihood.
    theta = numpyro.sample("theta", dist.Exponential(1))

    # Nested spatial hierarchy, level 1 -> level 2 -> level 3, all on dim=-3.
    with numpyro.plate("s1", N_s1, dim=-3):
        alpha_s1 = numpyro.sample("alpha_s1", dist.Normal(0, sigma_alpha_s1))
        beta_s1 = numpyro.sample("beta_s1", dist.Normal(0, sigma_beta_s1))
    with numpyro.plate("s2", N_s2, dim=-3):
        # hier1_id must have length N_s2: one level-1 parent per level-2 area.
        alpha_s2 = numpyro.sample(
            "alpha_s2", dist.Normal(alpha_s1[hier1_id], sigma_alpha_s2))
        beta_s2 = numpyro.sample(
            "beta_s2", dist.Normal(beta_s1[hier1_id], sigma_beta_s2))
    with space_plate:
        # hier2_id must have length N_s3: one level-2 parent per level-3 area.
        alpha_s3 = numpyro.sample(
            "alpha_s3", dist.Normal(alpha_s2[hier2_id], sigma_alpha_s3))
        beta_s3 = numpyro.sample(
            "beta_s3", dist.Normal(beta_s2[hier2_id], sigma_beta_s3))

    # Age effects as random walks over age groups. The first drift has a wide
    # Normal(0, 100) prior (the starting level); later drifts are the steps.
    with age_plate:
        alpha_age_drift = numpyro.sample(
            "alpha_age_drift",
            dist.Normal(0, _diffuse_first_scale(sigma_alpha_age, N_age)),
        )
        alpha_age = jnp.cumsum(alpha_age_drift, -2)
        beta_age_drift = numpyro.sample(
            "beta_age_drift",
            dist.Normal(0, _diffuse_first_scale(sigma_beta_age, N_age)),
        )
        beta_age = jnp.cumsum(beta_age_drift, -2)

    # Age x area intercepts, shape (N_s3, N_age, 1).
    with age_plate, space_plate:
        xi = numpyro.sample("xi", dist.Normal(alpha_age + alpha_s3, sigma_xi))

    # Per-area random walk over years with area-specific mean drift beta_s3.
    # nu has shape (N_s3, 1, N_t), and nu[..., 0] == 0.
    with space_plate, year_plate:
        nu_drift = numpyro.sample("nu_drift", dist.Normal(beta_s3, sigma_nu))
    nu = _walk_from_zero(nu_drift)

    # Per-age random walk over years with age-specific mean drift beta_age.
    # gamma has shape (N_age, N_t), and gamma[..., 0] == 0.
    with age_plate, year_plate:
        gamma_drift = numpyro.sample(
            "gamma_drift", dist.Normal(beta_age, sigma_gamma))
    gamma = _walk_from_zero(gamma_drift)

    # Broadcasts to (N_s3, N_age, N_t): axis order is (space, age, year).
    latent_rate = xi + nu + gamma

    with numpyro.plate("N", N):
        # Index in the same (space, age, year) order as latent_rate. JAX clamps
        # out-of-range gather indices instead of raising, so a swapped order
        # would silently produce wrong rates.
        mu_logit = latent_rate[hier3_id, age_id, year_id]
        mu = numpyro.deterministic("mu", expit(mu_logit))
        # concentration1 = mu * theta and concentration0 = (1 - mu) * theta
        # give a BetaBinomial with mean probability mu and concentration theta.
        numpyro.sample(
            "deaths",
            dist.BetaBinomial(mu * theta, (1 - mu) * theta, population),
            obs=deaths,
        )