Porting a model from BUGS to NumPyro

It has been a while since I last did any modeling. Here is my attempt (I haven't tested the code yet; I just want to give you a template of common NumPyro practices). As you will see, the code is very similar to the R code. The main change is that I split each random walk into its drift (increment) terms, which makes them easier to reparameterize.

# Reparameterization config for the `reparam` handler applied to `model`.
# LocScaleReparam(0) (centered=0) rewrites each listed Normal latent site into
# a fully non-centered form, which is the usual remedy for funnel-shaped
# hierarchical posteriors under NUTS. Every Normal random-effect site of the
# model is listed, including both age random-walk drifts (alpha and beta), so
# the two parallel age walks get the same geometry treatment.
reparam_config = {
    k: LocScaleReparam(0)
    for k in [
        "alpha_s1", "beta_s1",
        "alpha_s2", "beta_s2",
        "alpha_s3", "beta_s3",
        "alpha_age_drift", "beta_age_drift",
        "xi", "nu_drift", "gamma_drift",
    ]
}

@numpyro.handlers.reparam(config=reparam_config)
def model(...):
  # plates
  space_plate = numpyro.plate("space", N_s3, dim=-3)
  age_plate = numpyro.plate("age_groups", N_age, dim=-2)
  year_plate = numpyro.plate("year", N_t - 1, dim=-1)

  # global terms
  sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.Uniform(0, 2))
  sigma_beta_s1 = numpyro.sample("sigma_beta_s1", dist.Uniform(0, 2))
  sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.Uniform(0, 2))
  sigma_beta_s2 = numpyro.sample("sigma_beta_s2", dist.Uniform(0, 2))
  sigma_alpha_s3 = numpyro.sample("sigma_alpha_s3", dist.Uniform(0, 2))
  sigma_beta_s3 = numpyro.sample("sigma_beta_s3", dist.Uniform(0, 2))
  sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.Uniform(0, 2))
  sigma_beta_age = numpyro.sample("sigma_beta_age", dist.Uniform(0, 2))
  sigma_xi = numpyro.sample("sigma_xi", dist.Uniform(0, 2))
  sigma_nu = numpyro.sample("sigma_nu", dist.Uniform(0, 2))
  sigma_gamma = numpyro.sample("sigma_gamma", dist.Uniform(0, 2))
  theta = numpyro.sample("theta", dist.Exponential(1))

  # spatial hierarchy
  with numpyro.plate("s1", N_s1, dim=-3):
    alpha_s1 = numpyro.sample("alpha_s1", dist.Normal(0, sigma_alpha_s1))
    beta_s1 = numpyro.sample("beta_s1", dist.Normal(0, sigma_beta_s1))
  with numpyro.plate("s2", N_s2, dim=-3):
    # FIXME: hier1_id should have length N_s2, hier2_id should have length N_s3
    alpha_s2 = numpyro.sample("alpha_s2", dist.Normal(alpha_s1[hier1_id], sigma_alpha_s2))
    beta_s2 = numpyro.sample("beta_s2", dist.Normal(beta_s1[hier1_id], sigma_beta_s2))
  with space_plate:
    alpha_s3 = numpyro.sample("alpha_s3", dist.Normal(alpha_s2[hier2_id], sigma_alpha_s3))
    beta_s3 = numpyro.sample("beta_s3", dist.Normal(beta_s2[hier2_id], sigma_beta_s3))

  # age
  with age_plate:
    alpha_age_drift_scale = jnp.pad(
        jnp.broadcast_to(sigma_alpha_age, N_age - 1), (1, 0), constant_values=100.)[:, None]
    alpha_age_drift = numpyro.sample("alpha_age_drift", dist.Normal(0, alpha_age_drift_scale))
    alpha_age = jnp.cumsum(alpha_age_drift, -2)
    beta_age_drift_scale = jnp.pad(
        jnp.broadcast_to(sigma_beta_age, N_age - 1), [(1, 0)], constant_values=100.)[:, None]
    beta_age_drift = numpyro.sample("beta_age_drift", dist.Normal(0, beta_age_drift_scale))
    beta_age = jnp.cumsum(beta_age_drift, -2)

  # age-space interactions
  with age_plate, space_plate:
    xi = numpyro.sample("xi", dist.Normal(alpha_age + alpha_s3, sigma_xi))

  # space-time random walk
  with space_plate, year_plate:
    nu_drift = numpyro.sample("nu_drift", dist.Normal(beta_s3, sigma_nu))
    nu = jnp.pad(jnp.cumsum(nu_drift, -1), [(0, 0), (0, 0), (1, 0)])

  # age-time random walk
  with age_plate, year_plate:
    gamma_drift = numpyro.sample("gamma_drift", dist.Normal(beta_age, sigma_gamma))
    gamma = jnp.pad(jnp.cumsum(gamma_drift, -1), [(0, 0), (1, 0)])

  latent_rate = xi + nu + gamma
  # likelihood
  with numpyro.plate("N", N):
    mu_logit = latent_rate[age_id, hier3_id, year_id]
    mu = numpyro.deterministic("mu", expit(mu_logit))
    numpyro.sample("deaths", dist.BetaBinomial(mu * theta, (1 - mu) * theta, population), obs=deaths)