Porting a model from BUGS to numpyro

Hi there,

I’m translating a model I wrote in BUGS into numpyro and I need some help making it correct and efficient. The model produces mortality estimates for different age groups over space and time. The model uses the following effects to form a logit-transformed death rate:

  • global intercept and slope (alpha0, beta0)
  • spatial intercepts and slopes (alpha_s, beta_s). These are produced from a three tier nested administrative hierarchy. The s1 terms represent the largest geography and are made up of s2 units, which are made up of s3 units – think: state → county → tract. The s2 terms are centred on the s1 terms, and the s3 terms are centred on the s2 terms.
  • age terms (alpha_age, beta_age): a random walk over age groups, producing a curve with higher infant and older-age mortality
  • age-space interaction (xi) allowing different spatial units to have different age curves
  • space/age random walks (nu and gamma) allowing different spatial units/age groups to have non-linear temporal patterns

We map the death rate to death counts using a beta-binomial likelihood with known populations (i.e. the number of trials).

More detail on the model can be found in the paper. The model inference was performed in nimble, which compiles the model to C++ and runs MCMC. The BUGS model is below:

model{
# global terms
alpha0 ~ dnorm(0, 0.00001)
beta0  ~ dnorm(0, 0.00001)

# spatial hierarchy
for(s1 in 1:N_s1){
  alpha_s1[s1] ~ dnorm(0, sd = sigma_alpha_s1)
  beta_s1[s1] ~ dnorm(0, sd = sigma_beta_s1)
}
sigma_alpha_s1 ~ dunif(0,2)
sigma_beta_s1  ~ dunif(0,2)

for(s2 in 1:N_s2){
  alpha_s2[s2] ~ dnorm(alpha_s1[grid.lookup.s2[s2, 2]], sd = sigma_alpha_s2) # centred on s1 terms
  beta_s2[s2] ~ dnorm(beta_s1[grid.lookup.s2[s2, 2]], sd = sigma_beta_s2)
}
sigma_alpha_s2 ~ dunif(0,2)
sigma_beta_s2  ~ dunif(0,2)

for(s in 1:N_space){ # s = s3, N_space = N_s3
  alpha_s3[s] ~ dnorm(alpha_s2[grid.lookup[s, 2]], sd = sigma_alpha_s3) # centred on s2 terms
  beta_s3[s]  ~ dnorm(beta_s2[grid.lookup[s, 2]], sd = sigma_beta_s3)
}
sigma_alpha_s3 ~ dunif(0,2)
sigma_beta_s3  ~ dunif(0,2)

# age
alpha_age[1] <- alpha0 # initialise first terms for RW
beta_age[1]  <- beta0
for(a in 2:N_age_groups){
  alpha_age[a] ~ dnorm(alpha_age[a-1], sd = sigma_alpha_age) # RW based on previous age group
  beta_age[a]  ~ dnorm(beta_age[a-1], sd = sigma_beta_age)
}
sigma_alpha_age ~ dunif(0,2)
sigma_beta_age ~ dunif(0,2)

# age-space interactions
for(a in 1:N_age_groups) {
  for(s in 1:N_space) {
    xi[a, s] ~ dnorm(alpha_age[a] + alpha_s3[s], sd = sigma_xi)
  }
}
sigma_xi ~ dunif(0,2)

# space-time random walk
for(s in 1:N_space){
  nu[s, 1] <- 0
  for(t in 2:N_year) {
    nu[s, t] ~ dnorm(nu[s, t-1] + beta_s3[s], sd = sigma_nu)
  }
}
sigma_nu ~ dunif(0,2)

# age-time random walk
for(a in 1:N_age_groups){
  gamma[a, 1] <- 0
  for(t in 2:N_year) {
    gamma[a, t] ~ dnorm(gamma[a, t-1] + beta_age[a], sd = sigma_gamma)
  }
}
sigma_gamma ~ dunif(0,2)

for(a in 1:N_age_groups) {
  for(s in 1:N_space) {
    for(t in 1:N_year) {
      latent_rate[a, s, t] <- xi[a, s] + nu[s, t] + gamma[a, t]
    }
  }
}

for (i in 1:N) {
	# y is number of deaths in that cell
	# mu is predicted death rate in that cell
	# n is the number of people in that cell
	y[i] ~ dbetabin(alpha[i], beta[i], n[i])
	alpha[i] <- mu[i] * theta
	beta[i] <- (1 - mu[i]) * theta
	logit(mu[i]) <- latent_rate[age[i], space[i], yr[i]]
}
theta ~ dexp(0.1)
}

My attempt at converting this into numpyro is below:

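import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.scipy.special import expit

def model(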
    age_id,
    hier1_id,
    hier2_id,
    hier3_id,
    year_id,
    population,
    deaths=None
):
    # global terms
    alpha0 = numpyro.sample("alpha0", dist.Normal(0.0, 100.0))
    beta0 = numpyro.sample("beta0", dist.Normal(0.0, 100.0))

    # spatial hierarchy
    N_s1 = len(np.unique(hier1_id))
    sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.Uniform(0.0, 2.0))
    sigma_beta_s1 = numpyro.sample("sigma_beta_s1", dist.Uniform(0.0, 2.0))
    with numpyro.plate("plate_s1", N_s1):
        z_s1 = numpyro.sample("z_s1", dist.Normal(0, 1).expand([2, N_s1]))

    N_s2 = len(np.unique(hier2_id))
    sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.Uniform(0.0, 2.0))
    sigma_beta_s2 = numpyro.sample("sigma_beta_s2", dist.Uniform(0.0, 2.0))
    with numpyro.plate("plate_s2", N_s2):
        z_s2 = numpyro.sample("z_s2", dist.Normal(0, 1).expand([2, N_s2]))
    
    N_s3 = len(np.unique(hier3_id))
    sigma_alpha_s3 = numpyro.sample("sigma_alpha_s3", dist.Uniform(0.0, 2.0))
    sigma_beta_s3 = numpyro.sample("sigma_beta_s3", dist.Uniform(0.0, 2.0))
    with numpyro.plate("plate_s3", N_s3):
        z_s3 = numpyro.sample("z_s3", dist.Normal(0, 1).expand([2, N_s3]))

    alpha_s = z_s1[0, hier1_id] * sigma_alpha_s1 \
        + z_s2[0, hier2_id] * sigma_alpha_s2 \
        + z_s3[0, hier3_id] * sigma_alpha_s3
    beta_s = z_s1[1, hier1_id] * sigma_beta_s1 + \
        z_s2[1, hier2_id] * sigma_beta_s2 + \
        z_s3[1, hier3_id] * sigma_beta_s3

    # age
    N_age = len(np.unique(age_id))
    sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.Uniform(0.0, 2.0))
    sigma_beta_age = numpyro.sample("sigma_beta_age", dist.Uniform(0.0, 2.0))
    alpha_age = numpyro.sample(
        "alpha_age", dist.GaussianRandomWalk(scale=sigma_alpha_age, num_steps=N_age)
    )
    beta_age = numpyro.sample(
        "beta_age", dist.GaussianRandomWalk(scale=sigma_beta_age, num_steps=N_age)
    )
    
    # age-space interaction
    sigma_xi = numpyro.sample("sigma_xi", dist.Uniform(0.0, 2.0))
    xi = numpyro.sample("xi", dist.Normal(0, sigma_xi).expand([N_age, N_s3]))

    # space-time random walk
    N_t = len(np.unique(year_id))
    sigma_nu = numpyro.sample("sigma_nu", dist.Uniform(0.0, 2.0))
    nu = numpyro.sample(
        "nu",
        dist.GaussianRandomWalk(
            scale=sigma_nu,
            num_steps=N_t
        ).expand([N_s3])
    )
    
    # age-time random walk
    sigma_gamma = numpyro.sample("sigma_gamma", dist.Uniform(0.0, 2.0))
    gamma = numpyro.sample(
        "gamma",
        dist.GaussianRandomWalk(
            scale=sigma_gamma,
            num_steps=N_t
        ).expand([N_age])
    )
    
    # beta-binomial likelihood
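    # alpha_s and beta_s are already one value per data row (built from the
    # row-level hier*_id arrays), so they are not indexed again; xi is
    # age x space, so its second index is the space ID, not the year
    latent_rate = alpha0 + alpha_age[age_id] + alpha_s \
        + (beta0 + beta_age[age_id] + beta_s) * year_id \
        + xi[age_id, hier3_id] + nu[hier3_id, year_id] + gamma[age_id, year_id]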
    mu = numpyro.deterministic("mu", expit(latent_rate))
    theta = numpyro.sample("theta", dist.Exponential(0.1))
    numpyro.sample("deaths", dist.BetaBinomial(mu * theta, (1 - mu) * theta, population), obs=deaths)
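Question 1 below asks whether the spatial effects work as intended. Summing the non-centred terms at the data level, as `alpha_s` does above, gives the same values as centring each level on its parent, as in the BUGS model. The small NumPy check below uses a made-up three-level hierarchy. The lookup arrays `parent_of_s2` and `parent_of_s3` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical hierarchy: 2 states -> 4 counties -> 8 tracts
parent_of_s2 = np.array([0, 0, 1, 1])              # county -> state
parent_of_s3 = np.array([0, 0, 1, 1, 2, 2, 3, 3])  # tract -> county
sigma_s1, sigma_s2, sigma_s3 = 0.5, 0.3, 0.2
z_s1 = rng.normal(size=2)
z_s2 = rng.normal(size=4)
z_s3 = rng.normal(size=8)

# centred form (as in BUGS): each level is its parent's value plus noise
alpha_s1 = sigma_s1 * z_s1
alpha_s2 = alpha_s1[parent_of_s2] + sigma_s2 * z_s2
alpha_s3_nested = alpha_s2[parent_of_s3] + sigma_s3 * z_s3

# summed non-centred form: map each tract to its county and state, then add
state_of_s3 = parent_of_s2[parent_of_s3]
alpha_s3_summed = (
    sigma_s1 * z_s1[state_of_s3]
    + sigma_s2 * z_s2[parent_of_s3]
    + sigma_s3 * z_s3
)
```

The two forms agree only if the three ID arrays are consistent row by row, i.e. each tract's county and state IDs come from the same lookup.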

The model compiles and runs, suggesting I’ve got the shapes correct at least. Here are my questions:

  1. Inference in nimble was done with its default MCMC samplers (conjugate updates where possible, random-walk Metropolis otherwise), where posterior geometry is not as important as it is for NUTS. When converting to numpyro, I made an effort to use a non-centred reparametrisation for the normal effects. However, I’m unsure whether the spatial effects are working as I intended, i.e. the s3 effects are centred on the s2 effects, which are centred on the s1 effects.
  2. Is .expand([N_age, N_s3]) the most efficient way to write the age-space interaction?
  3. Is expanding a GaussianRandomWalk with another dimension the correct way to write a random walk? For a pymc example, see the month_president_effect in this post: they would write nu = pm.GaussianRandomWalk("nu", sigma=sigma_nu, dims=("N_t", "N_s3"))
  4. In the BUGS model, I set the first element of each random walk to 0 so that the effect was identifiable. In these GaussianRandomWalk terms, the first element is not fixed at 0. How is the effect identifiable?

More importantly, although the model runs, all the iterations are divergent – so something is clearly wrong with the model! Finally, the full size model is big – age x space x time : 19 x 6791 x 18 – so any advice to make this model sample as efficiently as possible is welcome (there’s plenty of low-hanging fruit).

Big (but simple) model – a lot of help needed. I appreciate all of it. Cheers,

Theo

Some of my thoughts:

the s3 effects are centred on the s2 effects, which are centred on the s1 effects

Sounds reasonable to me. It is also better to reparameterize your GRW and also other Normal sites, e.g.

w_base = numpyro.sample("w_base", dist.GaussianRandomWalk(scale=1., num_steps=N_t))
w = w_base * sigma

Variables that you can reparam are: alpha_age, beta_age, xi, nu, gamma

  • Using plate instead of .expand() has further benefits, such as being able to inspect your model with the visualization tools. Something like
with numpyro.plate("plate_s1", N_s1):
    z_s1 = numpyro.sample("z_s1", dist.Normal(0, 1).expand([2, N_s1]))

can be changed to

with numpyro.plate("plate_s1", N_s1):
    z_s1 = numpyro.sample("z_s1", dist.Normal(0, 1).expand([2]).to_event(1))
# then later access z_s1[hier1_id, 0]

Is expanding a GaussianRandomWalk with another dimension the correct way to write a random walk?

As above, it is better to put your GRW inside a plate, if that’s what you want. The plate dimensions then become batch dimensions on the left, and the “random walk” dimension is the rightmost (event) one. You can add some assertions to be more confident about the shapes of your variables. Something like

space_plate = plate("space", ..., dim=-1)
with space_plate:
    sample("nu")
with plate("age", ..., dim=-2):
    with space_plate:
        sample("xi")
    sample("gamma")
latent_rate = ... + xi[..., None] + nu + gamma

IIUC, this will create an array with dimensions age x space x year, which you can then index into after computing latent_rate. If that’s not what you want, you can squeeze the second-to-rightmost dimension of gamma to get the current behavior of your code.

set the first element of random walks to be 0

You can sample num_steps=N_t - 1 steps and pad the GRW with a leading 0.

Thanks for the speedy reply.

I have implemented the reparametrisation and .expand([2]).to_event(1), and added padding (hopefully correctly, using jnp.pad) to the alpha_age and beta_age parameters:

def model(
    age_id,
    hier1_id,
    hier2_id,
    hier3_id,
    year_id,
    population,
    deaths=None
):
    # global terms
    alpha0 = numpyro.sample("alpha0", dist.Normal(0., 100.))
    beta0 = numpyro.sample("beta0", dist.Normal(0., 100.))

    # spatial hierarchy
    N_s1 = len(np.unique(hier1_id))
    sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.Uniform(0., 2.))
    sigma_beta_s1 = numpyro.sample("sigma_beta_s1", dist.Uniform(0., 2.))
    with numpyro.plate("plate_s1", N_s1):
        z_s1 = numpyro.sample("z_s1", dist.Normal(0, 1).expand([2]).to_event(1))

    N_s2 = len(np.unique(hier2_id))
    sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.Uniform(0., 2.))
    sigma_beta_s2 = numpyro.sample("sigma_beta_s2", dist.Uniform(0., 2.))
    with numpyro.plate("plate_s2", N_s2):
        z_s2 = numpyro.sample("z_s2", dist.Normal(0., 1.).expand([2]).to_event(1))
    
    N_s3 = len(np.unique(hier3_id))
    sigma_alpha_s3 = numpyro.sample("sigma_alpha_s3", dist.Uniform(0., 2.))
    sigma_beta_s3 = numpyro.sample("sigma_beta_s3", dist.Uniform(0., 2.))
    with numpyro.plate("plate_s3", N_s3):
        z_s3 = numpyro.sample("z_s3", dist.Normal(0., 1.).expand([2]).to_event(1))

    alpha_s = z_s1[hier1_id, 0] * sigma_alpha_s1 \
        + z_s2[hier2_id, 0] * sigma_alpha_s2 \
        + z_s3[hier3_id, 0] * sigma_alpha_s3
    beta_s = z_s1[hier1_id, 1] * sigma_beta_s1 + \
        z_s2[hier2_id, 1] * sigma_beta_s2 + \
        z_s3[hier3_id, 1] * sigma_beta_s3

    # age
    N_age = len(np.unique(age_id))
    sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.Uniform(0., 2.))
    sigma_beta_age = numpyro.sample("sigma_beta_age", dist.Uniform(0., 2.))
    rw_alpha_age = numpyro.sample(
        "rw_alpha_age", dist.GaussianRandomWalk(scale=1., num_steps=(N_age-1))
    )
    rw_beta_age = numpyro.sample(
        "rw_beta_age", dist.GaussianRandomWalk(scale=1., num_steps=(N_age-1))
    )

    alpha_age = jnp.pad(rw_alpha_age * sigma_alpha_age, (1,0))
    beta_age = jnp.pad(rw_beta_age * sigma_beta_age, (1,0))
    
    # age-space interaction
    sigma_xi = numpyro.sample("sigma_xi", dist.Uniform(0., 2.))
    z_xi = numpyro.sample("z_xi", dist.Normal(0., 1.).expand([N_age, N_s3]))
    xi = z_xi * sigma_xi

    # space-time random walk
    N_t = len(np.unique(year_id))
    sigma_nu = numpyro.sample("sigma_nu", dist.Uniform(0., 2.))
    rw_nu = numpyro.sample(
        "rw_nu",
        dist.GaussianRandomWalk(
            scale=1.,
            num_steps=N_t
        ).expand([N_s3])
    )
    nu = rw_nu * sigma_nu
    
    # age-time random walk
    sigma_gamma = numpyro.sample("sigma_gamma", dist.Uniform(0.0, 2.0))
    rw_gamma = numpyro.sample(
        "rw_gamma",
        dist.GaussianRandomWalk(
            scale=1.,
            num_steps=N_t
        ).expand([N_age])
    )
    gamma = rw_gamma * sigma_gamma
    
    # beta-binomial likelihood
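    # alpha_s and beta_s are already one value per data row, so they are not
    # indexed again; xi is age x space, so it is indexed by the space ID
    latent_rate = alpha0 + alpha_age[age_id] + alpha_s \
        + (beta0 + beta_age[age_id] + beta_s) * year_id \
        + xi[age_id, hier3_id] + nu[hier3_id, year_id] + gamma[age_id, year_id]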
    mu = numpyro.deterministic("mu", expit(latent_rate))
    theta = numpyro.sample("theta", dist.Exponential(0.1))
    numpyro.sample("deaths", dist.BetaBinomial(mu * theta, (1 - mu) * theta, population), obs=deaths)

However, I’m still a bit confused about the nested plates for the interaction and random walk terms. Do you mean something like:

space_plate = numpyro.plate("plate_space", N_s3)
age_plate = numpyro.plate("plate_age", N_age)

with space_plate:
    z_s3 = numpyro.sample("z_s3", dist.Normal(0., 1.).expand([2]).to_event(1))

    rw_nu = numpyro.sample(
        "rw_nu",
        dist.GaussianRandomWalk(
            scale=1.,
            num_steps=N_t
        )
    )

nu = rw_nu * sigma_nu

with age_plate:
    rw_gamma = numpyro.sample(
        "rw_gamma",
        dist.GaussianRandomWalk(
            scale=1.,
            num_steps=N_t
        )
    )
    with space_plate:
        z_xi = numpyro.sample("z_xi", dist.Normal(0., 1.))

gamma = rw_gamma * sigma_gamma
xi = z_xi * sigma_xi

My main questions about the template code you wrote:

  1. I’m unsure how dim=-1 and dim=-2 are working, mostly due to my unfamiliarity coming from R. Please could you explain this further and why they are necessary
  2. I don’t understand why nu and gamma would have the dimensions age x space x year rather than just space x year and age x year, since each of them is inside only one of the two plates. Equally, I don’t understand why xi[..., None] is required rather than indexing xi directly with the data IDs
  3. How do I pad a random walk inside a plate such that (e.g.) gamma[age_id, 0] are all 0 for identifiability

Cheers and looking forward to visualising all the plates,

Theo

why nu and gamma would have the dimension age x space x year and not just space x year and age x year

Here year is an event dimension, the space plate has dim=-1 and the age plate has dim=-2. So nu and gamma will have shapes space x year and age x 1 x year respectively, and adding them gives an array with shape age x space x year. Essentially, we use plates to declare the batch dimensions of a variable. In the loop notation from your first post, that sum is equivalent to

for(a in 1:N_age_groups) {
  for(s in 1:N_space) {
    for(t in 1:N_year) {
      latent_rate[a, s, t] <- nu[s, t] + gamma[a, t]
    }
  }
}

Typically, instead of writing loops, we just think of nu as belonging to the space plate and gamma as belonging to the age plate, with latent_rate being their sum. (I ignore the xi terms here for simplicity. In my last comment I used xi[..., None] to give xi the shape age x space x 1, which broadcasts over years, so the sum xi[..., None] + nu + gamma is equivalent to the original code in your first post.)

You can use the visualization tool from my last comment to interpret your model. Note that you need to use plate rather than .expand() for the tool to render it correctly. By the way, I would suggest computing the full latent_rate array first, as in the BUGS code, and then indexing it with the data IDs:

# NOTE: alpha_age[:, None] has shape age x 1; the trailing 1 broadcasts
# against alpha_s3 (shape: space), so xi stays age x space
xi = xi + (alpha0 + alpha_age[:, None] + alpha_s3)
nu = nu + beta_s3[:, None] * jnp.arange(N_year)
gamma = gamma + beta0 + beta_age[:, None, None] * jnp.arange(N_year)
latent_rate = xi[..., None] + nu + gamma
mu_logit = latent_rate[age_id, hier3_id, year_id]
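Here is a quick NumPy check, using random stand-in arrays, that the broadcast sum `xi[..., None] + nu + gamma` matches the explicit triple loop shown earlier. Array names follow the snippet above, and the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
N_age, N_space, N_year = 3, 4, 5
xi = rng.normal(size=(N_age, N_space))
nu = rng.normal(size=(N_space, N_year))
gamma = rng.normal(size=(N_age, 1, N_year))  # age plate at dim=-2 leaves a singleton space axis

latent_rate = xi[..., None] + nu + gamma

# explicit loops, as in the BUGS model
latent_rate_loop = np.empty((N_age, N_space, N_year))
for a in range(N_age):
    for s in range(N_space):
        for t in range(N_year):
            latent_rate_loop[a, s, t] = xi[a, s] + nu[s, t] + gamma[a, 0, t]
```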

How do I pad a random walk inside a plate such that (e.g.) gamma

One way is

with plate("age"):
    gamma = sample("gamma", GRW(..., num_steps=T-1))
    gamma = jnp.pad(gamma, [(0, 0), (1, 0)], constant_values=0.)  # pad only the year (last) axis

It has been a while since I last did any modeling. Here is my attempt (I haven’t tested the code yet; I just want to give you a template of common NumPyro practices). As you will see, the code is very similar to your BUGS code. I just split the random walks into drift terms to make them easier to reparameterize.

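import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.scipy.special import expit
from numpyro.infer.reparam import LocScaleReparam

reparam_config = {
    k: LocScaleReparam(0) for k in
    ["alpha_s1", "beta_s1", "alpha_s2", "beta_s2", "alpha_s3", "beta_s3",
     "alpha_age_drift", "beta_age_drift", "xi", "nu_drift", "gamma_drift"]}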

@numpyro.handlers.reparam(config=reparam_config)
def model(...):
  # plates
  space_plate = numpyro.plate("space", N_s3, dim=-3)
  age_plate = numpyro.plate("age_groups", N_age, dim=-2)
  year_plate = numpyro.plate("year", N_t - 1, dim=-1)

  # global terms
  sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.Uniform(0, 2))
  sigma_beta_s1 = numpyro.sample("sigma_beta_s1", dist.Uniform(0, 2))
  sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.Uniform(0, 2))
  sigma_beta_s2 = numpyro.sample("sigma_beta_s2", dist.Uniform(0, 2))
  sigma_alpha_s3 = numpyro.sample("sigma_alpha_s3", dist.Uniform(0, 2))
  sigma_beta_s3 = numpyro.sample("sigma_beta_s3", dist.Uniform(0, 2))
  sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.Uniform(0, 2))
  sigma_beta_age = numpyro.sample("sigma_beta_age", dist.Uniform(0, 2))
  sigma_xi = numpyro.sample("sigma_xi", dist.Uniform(0, 2))
  sigma_nu = numpyro.sample("sigma_nu", dist.Uniform(0, 2))
  sigma_gamma = numpyro.sample("sigma_gamma", dist.Uniform(0, 2))
  theta = numpyro.sample("theta", dist.Exponential(1))

  # spatial hierarchy
  with numpyro.plate("s1", N_s1, dim=-3):
    alpha_s1 = numpyro.sample("alpha_s1", dist.Normal(0, sigma_alpha_s1))
    beta_s1 = numpyro.sample("beta_s1", dist.Normal(0, sigma_beta_s1))
  with numpyro.plate("s2", N_s2, dim=-3):
    # FIXME: hier1_id should be the s1 parent of each s2 unit (length N_s2), and hier2_id the s2 parent of each s3 unit (length N_s3)
    alpha_s2 = numpyro.sample("alpha_s2", dist.Normal(alpha_s1[hier1_id], sigma_alpha_s2))
    beta_s2 = numpyro.sample("beta_s2", dist.Normal(beta_s1[hier1_id], sigma_beta_s2))
  with space_plate:
    alpha_s3 = numpyro.sample("alpha_s3", dist.Normal(alpha_s2[hier2_id], sigma_alpha_s3))
    beta_s3 = numpyro.sample("beta_s3", dist.Normal(beta_s2[hier2_id], sigma_beta_s3))

  # age
  with age_plate:
    alpha_age_drift_scale = jnp.pad(
        jnp.broadcast_to(sigma_alpha_age, N_age - 1), (1, 0), constant_values=100.)[:, None]
    alpha_age_drift = numpyro.sample("alpha_age_drift", dist.Normal(0, alpha_age_drift_scale))
    alpha_age = jnp.cumsum(alpha_age_drift, -2)
    beta_age_drift_scale = jnp.pad(
        jnp.broadcast_to(sigma_beta_age, N_age - 1), [(1, 0)], constant_values=100.)[:, None]
    beta_age_drift = numpyro.sample("beta_age_drift", dist.Normal(0, beta_age_drift_scale))
    beta_age = jnp.cumsum(beta_age_drift, -2)

  # age-space interactions
  with age_plate, space_plate:
    xi = numpyro.sample("xi", dist.Normal(alpha_age + alpha_s3, sigma_xi))

  # space-time random walk
  with space_plate, year_plate:
    nu_drift = numpyro.sample("nu_drift", dist.Normal(beta_s3, sigma_nu))
    nu = jnp.pad(jnp.cumsum(nu_drift, -1), [(0, 0), (0, 0), (1, 0)])

  # age-time random walk
  with age_plate, year_plate:
    gamma_drift = numpyro.sample("gamma_drift", dist.Normal(beta_age, sigma_gamma))
    gamma = jnp.pad(jnp.cumsum(gamma_drift, -1), [(0, 0), (1, 0)])

  latent_rate = xi + nu + gamma
  # likelihood
  with numpyro.plate("N", N):
    mu_logit = latent_rate[age_id, hier3_id, year_id]
    mu = numpyro.deterministic("mu", expit(mu_logit))
    numpyro.sample("deaths", dist.BetaBinomial(mu * theta, (1 - mu) * theta, population), obs=deaths)

For completeness and for anyone reading this thread, here is a working version of the initial model I had in my head.

def model(
    age_id,
    hier1_id,
    hier2_id,
    hier3_id,
    year_id,
    population,
    deaths=None
):

    N_s1 = len(np.unique(hier1_id))
    N_s2 = len(np.unique(hier2_id))
    N_s3 = len(np.unique(hier3_id))
    N_age = len(np.unique(age_id))
    N_t = len(np.unique(year_id))
    
    # plates
    space_plate = numpyro.plate("space", N_s3, dim=-1)
    age_plate = numpyro.plate("age_groups", N_age, dim=-2)

    # hyperparameters
    alpha0 = numpyro.sample("alpha0", dist.Normal(0., 100.))
    beta0 = numpyro.sample("beta0", dist.Normal(0., 100.))
    sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.Uniform(0., 2.))
    sigma_beta_s1 = numpyro.sample("sigma_beta_s1", dist.Uniform(0., 2.))
    sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.Uniform(0., 2.))
    sigma_beta_s2 = numpyro.sample("sigma_beta_s2", dist.Uniform(0., 2.))
    sigma_alpha_s3 = numpyro.sample("sigma_alpha_s3", dist.Uniform(0., 2.))
    sigma_beta_s3 = numpyro.sample("sigma_beta_s3", dist.Uniform(0., 2.))
    sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.Uniform(0., 2.))
    sigma_beta_age = numpyro.sample("sigma_beta_age", dist.Uniform(0., 2.))
    sigma_xi = numpyro.sample("sigma_xi", dist.Uniform(0., 2.))
    sigma_nu = numpyro.sample("sigma_nu", dist.Uniform(0., 2.))
    sigma_gamma = numpyro.sample("sigma_gamma", dist.Uniform(0.0, 2.0))
    theta = numpyro.sample("theta", dist.Exponential(0.1))

    # spatial hierarchy
    with numpyro.plate("plate_s1", N_s1):
        z_s1 = numpyro.sample("z_s1", dist.Normal(0, 1).expand([2]).to_event(1))
    
    with numpyro.plate("plate_s2", N_s2):
        z_s2 = numpyro.sample("z_s2", dist.Normal(0., 1.).expand([2]).to_event(1))
    
    with space_plate:
        z_s3 = numpyro.sample("z_s3", dist.Normal(0., 1.).expand([2]).to_event(1))
        # space-time random walk
        rw_nu = numpyro.sample(
            "rw_nu",
            dist.GaussianRandomWalk(scale=1., num_steps=(N_t-1))
        )
        rw_nu = jnp.pad(rw_nu, ((0, 0), (1, 0)))

    nu = rw_nu * sigma_nu

    alpha_s = z_s1[hier1_id, 0] * sigma_alpha_s1 \
        + z_s2[hier2_id, 0] * sigma_alpha_s2 \
        + z_s3[hier3_id, 0] * sigma_alpha_s3
    beta_s = z_s1[hier1_id, 1] * sigma_beta_s1 + \
        z_s2[hier2_id, 1] * sigma_beta_s2 + \
        z_s3[hier3_id, 1] * sigma_beta_s3

    # age
    rw_alpha_age = numpyro.sample(
        "rw_alpha_age", dist.GaussianRandomWalk(scale=1., num_steps=(N_age-1))
    )
    rw_beta_age = numpyro.sample(
        "rw_beta_age", dist.GaussianRandomWalk(scale=1., num_steps=(N_age-1))
    )

    alpha_age = jnp.pad(rw_alpha_age * sigma_alpha_age, (1,0))
    beta_age = jnp.pad(rw_beta_age * sigma_beta_age, (1,0))
    
    with age_plate:
        # age-time random walk
        rw_gamma = numpyro.sample(
            "rw_gamma",
            dist.GaussianRandomWalk(scale=1., num_steps=(N_t-1))
        )
        rw_gamma = jnp.squeeze(rw_gamma, axis=1)
        rw_gamma = jnp.pad(rw_gamma, ((0, 0), (1, 0)))
        # age-space interaction
        with space_plate:
            z_xi = numpyro.sample("z_xi", dist.Normal(0., 1.))
    xi = z_xi * sigma_xi
    
    gamma = rw_gamma * sigma_gamma
    
    # beta-binomial likelihood
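    # alpha_s and beta_s are already one value per data row, so they are not
    # indexed again; xi is age x space, so it is indexed by the space ID
    latent_rate = alpha0 + alpha_age[age_id] + alpha_s \
        + (beta0 + beta_age[age_id] + beta_s) * year_id \
        + xi[age_id, hier3_id] + nu[hier3_id, year_id] + gamma[age_id, year_id]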
    mu = numpyro.deterministic("mu", expit(latent_rate))
    numpyro.sample("deaths", dist.BetaBinomial(mu * theta, (1 - mu) * theta, population), obs=deaths)

However, I like the way you’ve written the model. Thank you! The @numpyro.handlers.reparam(config=reparam_config) decorator is a really handy way to avoid writing the non-centred parametrisations by hand. I also see how you’ve written it more like the initial BUGS code I gave you: each effect has the correct, minimal shape, and you only index into them inside a final data plate, numpyro.plate("N", N). This is better than my version of the model above, where effects like alpha_s and beta_s unnecessarily have the same length as the data.

So I’m going to get your version of the model working. But before I do:

  1. Where did the global intercept and global slope hyperparameters, alpha0 and beta0 go?
  2. What is the line alpha_age_drift_scale = jnp.pad( jnp.broadcast_to(sigma_alpha_age, N_age - 1), (1, 0), constant_values=100.)[:, None] doing?
  3. Why do the space-time random walk and age-time random walk need different padding? ((0, 0, 0, 0, 1, 0) vs (0, 0, 1, 0))
  4. How does numpyro treat the likelihood differently between the two models. i.e. what’s the difference between putting numpyro.sample("deaths", dist.BetaBinomial(mu * theta, (1 - mu) * theta, population), obs=deaths) in a plate or not in a plate?

I’ll work on this tomorrow and get it running.

Oh, it is great to hear that the inference works now (I was worried that you might not get results as good as with BUGS, because BUGS uses fancier algorithms).

  1. Where did the global intercept and global slope hyperparameters, alpha0 and beta0 go?
  2. What is the line alpha_age_drift_scale = jnp.pad( jnp.broadcast_to(sigma_alpha_age, N_age - 1), (1, 0), constant_values=100.)[:, None] doing?

alpha0 and beta0 are the initial values of alpha_age and beta_age, i.e. alpha_age_drift[0] and beta_age_drift[0]; I could split them out. They follow a Normal distribution with scale 100, i.e. alpha/beta_age_drift_scale[0] = 100. The [:, None] adds a trailing singleton dimension to alpha_age_drift, so that its age dimension lines up with age_plate (dim=-2).
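Here is a small sketch of that construction outside the model. The values of `sigma_alpha_age` and `z` are made up, and `z` stands in for the standardised (non-centred) `alpha_age_drift` draws. Padding the scale vector with 100 at the front gives the first age group the N(0, 100) prior of `alpha0`. `cumsum` along the age axis then turns the remaining drift terms into a random walk that starts at `alpha0`.

```python
import jax.numpy as jnp

N_age = 5
sigma_alpha_age = 0.3

# scale of each drift term: 100 for alpha0, sigma_alpha_age for each RW step;
# [:, None] adds a trailing singleton axis so age sits at dim=-2
drift_scale = jnp.pad(
    jnp.broadcast_to(sigma_alpha_age, (N_age - 1,)), (1, 0), constant_values=100.0
)[:, None]

z = jnp.array([0.5, 1.0, -1.0, 0.2, 0.0])[:, None]  # stand-in standard-normal draws
alpha_age_drift = z * drift_scale
alpha_age = jnp.cumsum(alpha_age_drift, axis=-2)     # alpha_age[0] == alpha0
```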

Why do the space-time random walk and age-time random walk need different padding?

That code just pads the first element of the last dimension (i.e. the year dimension) of nu and gamma with 0. The pad widths differ because nu_drift has rank 3 (space x 1 x year), whereas gamma_drift has rank 2 (age x year). (I wish numpy supported something like np.pad(x, (1, 0), axis=-1) to pad only the last dimension, but that feature is not available.) If you are unsure about the dimensions, you can use this in both cases:

x = jnp.pad(x, [(0, 0)] * (x.ndim - 1) + [(1, 0)])

(note that I just fixed the usage of jnp.pad in my code - previously I used a different pattern for pad_width)
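The rank-agnostic helper below follows that idea and pads only the last axis. The name `pad_first_year` is hypothetical. The example shapes mirror the simulated data (113 spatial units, 19 age groups, 18 years).

```python
import jax.numpy as jnp


def pad_first_year(x):
    """Prepend a 0 along the last (year) axis, leaving every other axis unpadded."""
    pad_width = [(0, 0)] * (x.ndim - 1) + [(1, 0)]
    return jnp.pad(x, pad_width)


nu_drift = jnp.ones((113, 1, 17))  # space x 1 x (year - 1)
gamma_drift = jnp.ones((19, 17))   # age x (year - 1)
nu = pad_first_year(jnp.cumsum(nu_drift, axis=-1))
gamma = pad_first_year(jnp.cumsum(gamma_drift, axis=-1))
```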

in a plate or not in a plate?

They should be treated the same way, but it is good practice to declare plates explicitly, so the code is more readable and corresponds one-to-one to its graphical plot.

An aside: Some context on the model and why I’m now using numpyro

So, the model I gave initially was written in BUGS syntax, but the inference itself was done in nimble. It’s a nice PPL by a nice team. nimble is like a “pick n mix” for samplers: you can assign any of the samplers they offer to any parameter, which gives the user a lot of control. The defaults are to use conjugate samplers where possible and random-walk (RW) samplers where not. They don’t have NUTS yet, though.

nimble runs fast for this model and scales nicely. Although the numpyro inference works, it runs much more slowly than nimble per iteration (6 minutes per 100 iterations for nimble vs ~30 minutes per 100 iterations for numpyro, both excluding compile times). But I’m comparing adaptive NUTS to nimble’s default samplers, so really I should compare ESS/time. I’ll do that once I’ve got the new (your) version of the model running.

Also, I’m currently running on simulated data (death records are potentially identifiable, so I can’t keep the real data on my laptop, but I can still build the model on simulated data). The simulated data has dimensions age 19, space 113, time 18 (real data: age 19, space 6791, time 18). It takes nimble about 10 days to get enough samples for the full-size model. Hopefully numpyro will scale up to the full-size model, and I should be able to use a GPU to help with that. NUTS should also make it sample more efficiently. Also, we want to try some funky models with neural networks inside them, and numpyro’s flax/haiku support is tasty.

Problems with the numpyro model

I’m getting an error when I run the model below:

@numpyro.handlers.reparam(config=reparam_config)
def model(
    space,
    age,
    time,
    lookup1,
    lookup2,
    population,
    deaths=None
):

    N_s1 = len(np.unique(lookup1))
    N_s2 = len(np.unique(lookup2))
    N_s3 = len(np.unique(space))
    N_age = len(np.unique(age))
    N_t = len(np.unique(time))
    N = len(population)
    
    # plates
    space_plate = numpyro.plate("space", N_s3, dim=-3)
    age_plate = numpyro.plate("age_groups", N_age, dim=-2)
    year_plate = numpyro.plate("year", N_t - 1, dim=-1)

    # hyperparameters
    sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.Uniform(0., 2.))
    sigma_beta_s1 = numpyro.sample("sigma_beta_s1", dist.Uniform(0., 2.))
    sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.Uniform(0., 2.))
    sigma_beta_s2 = numpyro.sample("sigma_beta_s2", dist.Uniform(0., 2.))
    sigma_alpha_s3 = numpyro.sample("sigma_alpha_s3", dist.Uniform(0., 2.))
    sigma_beta_s3 = numpyro.sample("sigma_beta_s3", dist.Uniform(0., 2.))
    sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.Uniform(0., 2.))
    sigma_beta_age = numpyro.sample("sigma_beta_age", dist.Uniform(0., 2.))
    sigma_xi = numpyro.sample("sigma_xi", dist.Uniform(0., 2.))
    sigma_nu = numpyro.sample("sigma_nu", dist.Uniform(0., 2.))
    sigma_gamma = numpyro.sample("sigma_gamma", dist.Uniform(0.0, 2.0))
    theta = numpyro.sample("theta", dist.Exponential(0.1))

    # spatial hierarchy
    with numpyro.plate("s1", N_s1, dim=-3):
        alpha_s1 = numpyro.sample("alpha_s1", dist.Normal(0, sigma_alpha_s1))
        beta_s1 = numpyro.sample("beta_s1", dist.Normal(0, sigma_beta_s1))
    with numpyro.plate("s2", N_s2, dim=-3):
        alpha_s2 = numpyro.sample("alpha_s2", dist.Normal(alpha_s1[lookup1], sigma_alpha_s2))
        beta_s2 = numpyro.sample("beta_s2", dist.Normal(beta_s1[lookup1], sigma_beta_s2))
    with space_plate:
        alpha_s3 = numpyro.sample("alpha_s3", dist.Normal(alpha_s2[lookup2], sigma_alpha_s3))
        beta_s3 = numpyro.sample("beta_s3", dist.Normal(beta_s2[lookup2], sigma_beta_s3))

    # age
    with age_plate:
        alpha_age_drift_scale = jnp.pad(
            jnp.broadcast_to(
                sigma_alpha_age,
                N_age - 1
            ),
            (1, 0),
            constant_values=100. # pad so first term is alpha0, the global intercept with prior N(0, 100)
        )[:, jnp.newaxis]
        alpha_age_drift = numpyro.sample("alpha_age_drift", dist.Normal(0, alpha_age_drift_scale))
        alpha_age = jnp.cumsum(alpha_age_drift, -2)

        beta_age_drift_scale = jnp.pad(
            jnp.broadcast_to(
                sigma_beta_age,
                N_age - 1
            ),
            (1, 0),
            constant_values=100.
        )[:, jnp.newaxis]
        beta_age_drift = numpyro.sample("beta_age_drift", dist.Normal(0, beta_age_drift_scale))
        beta_age = jnp.cumsum(beta_age_drift, -2)
    
    # age-space interactions
    with age_plate, space_plate:
        xi = numpyro.sample("xi", dist.Normal(alpha_age + alpha_s3, sigma_xi))
    
    # space-time random walk
    with space_plate, year_plate:
        nu_drift = numpyro.sample("nu_drift", dist.Normal(beta_s3, sigma_nu))
        nu = jnp.pad(jnp.cumsum(nu_drift, -1), [(0, 0), (0, 0), (1, 0)])

    # age-time random walk
    with age_plate, year_plate:
        gamma_drift = numpyro.sample("gamma_drift", dist.Normal(beta_age, sigma_gamma))
        gamma = jnp.pad(jnp.cumsum(gamma_drift, -1), [(0, 0), (1, 0)])
    
    # likelihood
    latent_rate = xi + nu + gamma
    print(alpha_s3.shape)
    print(alpha_age.shape)
    print(xi.shape)
    print(nu.shape)
    print(gamma.shape)
    print(latent_rate.shape)
    with numpyro.plate("N", N):
        mu_logit = latent_rate[space, age, time]
        mu = numpyro.deterministic("mu", expit(mu_logit))
        numpyro.sample("deaths", dist.BetaBinomial(mu * theta, (1 - mu) * theta, population), obs=deaths)

This is the same as what you wrote, except that the indexing has changed to latent_rate[space, age, time] to match the plate dims.
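If you are comparing this with the BUGS version: the jnp.cumsum over drift terms in the age block is a non-centred way of writing the recursive random walk. The first padded element, with prior sd 100, acts as the starting level. The NumPy sketch below checks that equivalence in 1-D. The drift values and sigma_alpha_age are hypothetical, and the trailing singleton axis used in the model is dropped.

```python
import numpy as np

rng = np.random.default_rng(0)
N_age = 19
sigma_alpha_age = 0.3  # hypothetical value, for illustration only

# Drift terms as in the model: the first has a wide prior (sd 100) and acts as
# the starting level; the remaining N_age - 1 have sd sigma_alpha_age.
scales = np.concatenate([[100.0], np.full(N_age - 1, sigma_alpha_age)])
drift = rng.normal(0.0, scales)

# Non-centred construction used in the numpyro model (cumsum over drifts).
alpha_age_cumsum = np.cumsum(drift)

# BUGS-style recursive definition: each age term is centred on the previous one.
alpha_age_recursive = np.empty(N_age)
alpha_age_recursive[0] = drift[0]
for a in range(1, N_age):
    alpha_age_recursive[a] = alpha_age_recursive[a - 1] + drift[a]

print(np.allclose(alpha_age_cumsum, alpha_age_recursive))  # True
```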

I get the error ValueError: Dirichlet distribution got invalid concentration parameter. I don’t think it’s a shape problem, but in case you want to check that the parameters have the right shapes, the print statements output:

alpha_s3:    (113, 1, 1)
alpha_age:   (19, 1)
xi:          (113, 19, 1)
nu:          (113, 1, 18)
gamma:       (19, 18)
latent_rate: (113, 19, 18)

Let me know what you think the problem is.

I think we are getting some invalid values for the parameters of BetaBinomial. Could you print mu.shape and look at mu * theta and (1 - mu) * theta to see whether they are valid? The shapes you printed look reasonable to me.
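The NumPy helper below is a minimal sketch of that check. It flags every observation where either BetaBinomial concentration is not finite or not strictly positive. The function name and the example values are hypothetical.

```python
import numpy as np

def invalid_concentration_mask(conc):
    """True wherever a concentration is not finite or not strictly positive."""
    conc = np.asarray(conc)
    return ~(np.isfinite(conc) & (conc > 0))

# Hypothetical float32 values for mu * theta and (1 - mu) * theta.
alpha = np.array([3.2e-14, 0.81, 0.81], dtype=np.float32)
beta = np.array([0.81, 0.0, np.nan], dtype=np.float32)

bad = invalid_concentration_mask(alpha) | invalid_concentration_mask(beta)
print(np.flatnonzero(bad))  # [1 2]
```

A tiny but positive value such as 3.2e-14 passes this check. An exact zero or a NaN does not.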

I added

print(mu.shape)
print((mu * theta).view())
print(((1 - mu) * theta).view())

which gives

(38646,)
[3.1522758e-14 2.7437749e-14 2.1988657e-13 ... 8.0548018e-01 8.0548018e-01
 8.0548018e-01]
[0.8054802 0.8054802 0.8054802 ... 0.        0.        0.       ]
(38646,)
Traced<ConcreteArray([3.1522758e-14 2.7437749e-14 2.1988657e-13 ... 8.0548018e-01 8.0548018e-01
 8.0548018e-01])>with<JVPTrace(level=2/0)>
  with primal = DeviceArray([3.1522758e-14, 2.7437749e-14, 2.1988657e-13, ...,
                             8.0548018e-01, 8.0548018e-01, 8.0548018e-01], dtype=float32)
       tangent = Traced<ShapedArray(float32[38646]):JaxprTrace(level=1/0)>
Traced<ConcreteArray([0.8054802 0.8054802 0.8054802 ... 0.        0.        0.       ])>with<JVPTrace(level=2/0)>
  with primal = DeviceArray([0.8054802, 0.8054802, 0.8054802, ..., 0.       , 0.       ,
                             0.       ], dtype=float32)
       tangent = Traced<ShapedArray(float32[38646]):JaxprTrace(level=1/0)>
(38646,)
Traced<ShapedArray(float32[38646])>with<JVPTrace(level=3/0)>
  with primal = Traced<ShapedArray(float32[38646])>with<DynamicJaxprTrace(level=1/0)>
       tangent = Traced<ShapedArray(float32[38646]):JaxprTrace(level=2/0)>
Traced<ShapedArray(float32[38646])>with<JVPTrace(level=3/0)>
  with primal = Traced<ShapedArray(float32[38646])>with<DynamicJaxprTrace(level=1/0)>
       tangent = Traced<ShapedArray(float32[38646]):JaxprTrace(level=2/0)>
...

ValueError: Dirichlet distribution got invalid concentration parameter

I didn’t get this error with the other model (if that helps).

It seems that 1 - mu is getting the value 0, i.e. the latent rate is very large. You might want to check why xi, nu, or gamma is getting large values.
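Why can 1 - mu become exactly 0? JAX computes in float32 by default, and in float32 the logistic function rounds to exactly 1.0 once the logit is above roughly 17. The NumPy sketch below mimics float32 arithmetic to show the effect. It also shows the identity 1 - expit(x) = expit(-x), which avoids the cancellation. In the model that would presumably mean passing expit(-mu_logit) * theta as the second concentration instead of (1 - mu) * theta. This is a suggestion only; it has not been tested on this model. The same underflow can still hit mu * theta for very negative logits.

```python
import numpy as np

def expit32(x):
    """Logistic sigmoid evaluated entirely in float32 (mimics JAX's default dtype)."""
    x = np.asarray(x, dtype=np.float32)
    return np.float32(1.0) / (np.float32(1.0) + np.exp(-x))

theta = np.float32(1.0)  # hypothetical concentration value
latent_rate = np.array([0.0, 5.0, 20.0, 40.0], dtype=np.float32)

mu = expit32(latent_rate)
naive_beta = (np.float32(1.0) - mu) * theta   # 1 - mu is exactly 0 once mu rounds to 1.0
stable_beta = expit32(-latent_rate) * theta   # 1 - expit(x) == expit(-x), no cancellation

print(naive_beta)
print(stable_beta)
```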

No idea about the details of your model, but it looks like your latent dimension is pretty high. If you are able to compute conditional posterior distributions for some of the variables (especially the ones of highest dimension), you might get better performance from HMCGibbs. HMC doesn’t necessarily work that well as the latent dimension gets very large.

@martinjankowiak The current data I’m using has dimensions 19 x 113 x 18 (age x space x time), and NUTS runs for the earlier version of the model, so I expect it to run with this newer version. On the full data, the dimensionality goes up to 19 x 6791 x 18, so the latent dimension might cause some issues. If that’s the case, I’ll try out HMCGibbs, although writing custom Gibbs samplers for this model might be tricky. Unless you know a quick hack for writing custom Gibbs samplers?


@fehiepsi It looks like both nu and gamma can get quite large, with values (depending on the seed) reaching about ±100 and becoming more positive/negative as we move through the time dimension. This behaviour is expected, as these effects are centred on slopes and the slope effect accumulates over time.

I think the model is acting as expected, but you might think otherwise. And I’m a bit confused why this wasn’t an issue for the old version of the model. Is there a fix for this so the likelihood doesn’t run into trouble?

(Some reproducible code here if you’d like to take a look. Thanks again for your help so far)

Cheers both

why this wasn’t an issue for the old version of the model

It is hard to say. I think the main reason is that the latent variables are initialized with different values (because the two versions parameterize the model differently).

Looking at the code, I think you might want to add beta_age_drift to the reparam config. In addition, you might want to call numpyro.enable_x64() at the top of your program to get better precision. Hopefully, it will help you avoid the 0 values. Otherwise, you can clamp mu to make it a positive number: mu = jnp.clip(mu, a_min=jnp.finfo(mu.dtype).tiny).
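The clip above only bounds mu from below. In the printout earlier in the thread, though, the zeros were in (1 - mu) * theta, so an upper bound may matter too. Below is a NumPy sketch of a two-sided clamp into the open interval (0, 1) for mu's dtype. The function name and values are hypothetical. The same two bounds could be passed to jnp.clip inside the model.

```python
import numpy as np

def clamp_probability(mu):
    """Clamp mu into the open interval (0, 1) for its floating-point dtype."""
    mu = np.asarray(mu)
    finfo = np.finfo(mu.dtype)
    lower = finfo.tiny                     # smallest positive normal number
    upper = mu.dtype.type(1) - finfo.eps   # just below 1, so 1 - upper is still positive
    return np.clip(mu, lower, upper)

mu = np.array([0.0, 0.25, 1.0], dtype=np.float32)  # hypothetical saturated values
theta = np.float32(2.0)                             # hypothetical concentration
mu_safe = clamp_probability(mu)
alpha = mu_safe * theta
beta = (np.float32(1.0) - mu_safe) * theta
print(alpha.min() > 0, beta.min() > 0)  # True True
```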

You can inspect the shapes to make sure that things look as expected:

import numpyro

with numpyro.handlers.seed(rng_seed=1):
    # args, kwargs are your model inputs, like age, space, time, lookups,...
    trace = numpyro.handlers.trace(model).get_trace(*args, **kwargs)
print(numpyro.util.format_shapes(trace))

@theo does “runs” mean that the code executes, or that you get plausible posterior samples, small r_hats, etc.?

There isn’t much in the way of machinery for automating Gibbs steps, so you’d have to derive the update equations yourself. It might be worth doing, though, for any high-dimensional latent variables (e.g. any high-dimensional normally distributed variables).
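Here is one update that can be derived by hand for this model. In the posted code, alpha_s2 appears only in its own prior (centred on alpha_s1) and as the prior mean of its alpha_s3 children. Its full conditional is therefore a conjugate normal: the precisions add, and the mean is precision-weighted. The NumPy sketch below assumes 0-based lookup arrays, lookup_s2_to_s1 and lookup_s3_to_s2. These are hypothetical stand-ins for grid.lookup.s2[, 2] and grid.lookup[, 2].

To use it with numpyro.infer.HMCGibbs, it would need to be rewritten in jax.numpy (e.g. jnp.bincount with a static length). It would then be wrapped as a gibbs_fn(rng_key, gibbs_sites, hmc_sites) returning {"alpha_s2": ...}. That wrapping is untested here.

```python
import numpy as np

def alpha_s2_conditional(alpha_s1, alpha_s3, lookup_s2_to_s1, lookup_s3_to_s2,
                         sigma_alpha_s2, sigma_alpha_s3):
    """Mean and sd of p(alpha_s2 | alpha_s1, alpha_s3, sigmas) for each s2 unit.

    Prior:    alpha_s2[j] ~ N(alpha_s1[lookup_s2_to_s1[j]], sigma_alpha_s2)
    Children: alpha_s3[k] ~ N(alpha_s2[lookup_s3_to_s2[k]], sigma_alpha_s3)
    """
    n_s2 = len(lookup_s2_to_s1)
    # Number of s3 children and the sum of their values for each s2 unit.
    n_children = np.bincount(lookup_s3_to_s2, minlength=n_s2)
    child_sum = np.bincount(lookup_s3_to_s2, weights=alpha_s3, minlength=n_s2)

    prior_prec = 1.0 / sigma_alpha_s2**2
    child_prec = 1.0 / sigma_alpha_s3**2
    post_prec = prior_prec + n_children * child_prec
    post_mean = (prior_prec * alpha_s1[lookup_s2_to_s1] + child_prec * child_sum) / post_prec
    return post_mean, 1.0 / np.sqrt(post_prec)

def gibbs_update_alpha_s2(rng, *args):
    """Draw alpha_s2 from its full conditional."""
    mean, sd = alpha_s2_conditional(*args)
    return rng.normal(mean, sd)

# Toy hierarchy: 2 s1 units, 3 s2 units, 5 s3 units (hypothetical values).
alpha_s1 = np.array([0.0, 1.0])
lookup_s2_to_s1 = np.array([0, 0, 1])
alpha_s3 = np.array([0.2, 0.4, -0.1, 1.5, 0.5])
lookup_s3_to_s2 = np.array([0, 0, 1, 2, 2])

mean, sd = alpha_s2_conditional(alpha_s1, alpha_s3, lookup_s2_to_s1, lookup_s3_to_s2, 1.0, 0.5)
draw = gibbs_update_alpha_s2(np.random.default_rng(0), alpha_s1, alpha_s3,
                             lookup_s2_to_s1, lookup_s3_to_s2, 1.0, 0.5)
```

The same pattern applies to beta_s2. Each s2 unit's update depends only on its parent and its children, so all units can be drawn in one vectorised step.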

Hi again, happy new year both.

@fehiepsi A couple of things got the code running (executing). First, I tried different initialisation strategies. The code raised ValueError: Dirichlet distribution got invalid concentration parameter with the default init_to_uniform and with init_to_median, but it executed with init_to_feasible. I then tried numpyro.enable_x64() with the default initialisation, and the code ran. As the first option worked, I didn’t clamp mu to be positive.

However, the results did not match up. The init_to_feasible strategy produced no divergent iterations but was much slower, at ~15 s/it. The numpyro.enable_x64() strategy was much faster, at ~2 s/it, but all the iterations were divergent. When I ran this with nimble, I initialised to specific values.

(@martinjankowiak too) I haven’t focussed on plausible posterior samples so far, because I’m currently working on simulated data rather than the real data (which is on a secure system, so difficult to use for model development). Even so, this performance on simulated data is concerning. I could simulate new data from this exact model (i.e. setting the hyperpriors to fixed values; if there is a quick way of doing this, let me know), so we can at least see whether numpyro can fit this model in an idealised situation. Otherwise, do you have any suggestions that might improve performance on the current simulated data?
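Within numpyro, one likely route to simulation is numpyro.handlers.condition(model, data={...}), which pins the hyperparameter sites to fixed values. You would then draw deaths from the prior with numpyro.infer.Predictive, passing deaths=None, assuming deaths is a model argument. As a dependency-free alternative, here is a NumPy sketch that forward-simulates a simplified version of the model. It keeps the nested spatial intercepts, an age random walk and the beta-binomial likelihood (drawn as a beta then a binomial), but drops the slopes, xi, nu and gamma. The function name, arguments and hyperparameter values are all hypothetical.

```python
import numpy as np

def simulate_deaths(rng, population, lookup_s3_to_s2, lookup_s2_to_s1, n_s1, n_age,
                    sigma_s1=0.5, sigma_s2=0.3, sigma_s3=0.2, sigma_age=0.2,
                    alpha0=-5.0, theta=50.0):
    """Draw death counts from a simplified model with fixed hyperparameters.

    population has shape (n_space, n_age); lookups are 0-based parent indices.
    """
    # Nested spatial intercepts: s2 centred on s1, s3 centred on s2.
    alpha_s1 = rng.normal(0.0, sigma_s1, n_s1)
    alpha_s2 = rng.normal(alpha_s1[lookup_s2_to_s1], sigma_s2)
    alpha_s3 = rng.normal(alpha_s2[lookup_s3_to_s2], sigma_s3)
    # Age random walk built from cumulative drifts.
    alpha_age = np.cumsum(rng.normal(0.0, sigma_age, n_age))

    logit_mu = alpha0 + alpha_s3[:, None] + alpha_age[None, :]  # (n_space, n_age)
    mu = 1.0 / (1.0 + np.exp(-logit_mu))
    # Beta-binomial with mean mu and concentration theta, as in the model's likelihood.
    p = rng.beta(mu * theta, (1.0 - mu) * theta)
    deaths = rng.binomial(population, p)
    return deaths, mu

rng = np.random.default_rng(1)
lookup_s2_to_s1 = np.array([0, 0, 1])
lookup_s3_to_s2 = np.array([0, 0, 1, 2, 2, 2])
population = np.full((6, 19), 1000)
deaths, mu = simulate_deaths(rng, population, lookup_s3_to_s2, lookup_s2_to_s1, n_s1=2, n_age=19)
print(deaths.shape)  # (6, 19)
```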

Finally, if we get this converging nicely and it would be of interest to you, I’d be happy to write this up in the numpyro style and contribute it as an example. It could be useful as a real-world problem, as a problem over space and time, and as an example of numpyro.handlers.reparam. We can discuss this on GitHub if you’d like me to open a PR.

I haven’t looked at the detailed structure of your model, but presumably you can integrate out some of the normal latent variables by hand. This would reduce the dimensionality of the latent space and would probably lead to better behavior. Alternatively, as mentioned before, you could use Gibbs updates w.r.t. the normal latent variables.
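Here is one concrete case from the posted code. alpha_s3 appears only as the mean of xi, so it could be integrated out by hand. Conditionally on alpha_s2, each row xi[s, :] is then multivariate normal. Its mean is alpha_age + alpha_s2[lookup2[s]], and its covariance is sigma_xi^2 * I + sigma_alpha_s3^2 * J, where J is the all-ones matrix, because every age shares the same alpha_s3[s]. The same argument applies to beta_s3 and nu_drift. This removes the alpha_s3 sites, while the larger xi block is still sampled (in numpyro, presumably via dist.MultivariateNormal with covariance_matrix).

The NumPy sketch below checks the marginal by Monte Carlo against the two-stage model. All values are hypothetical, and a small n_age is used.

```python
import numpy as np

rng = np.random.default_rng(0)
n_age = 4                                  # small for illustration (19 in the thread)
sigma_xi, sigma_alpha_s3 = 0.3, 0.5        # hypothetical values
alpha_age = np.linspace(-1.0, 1.0, n_age)  # hypothetical age curve
alpha_s2_parent = 0.2                      # hypothetical value of alpha_s2[lookup2[s]]

# Marginal moments of xi[s, :] after integrating out alpha_s3[s]:
# the shared alpha_s3[s] adds sigma_alpha_s3**2 to every covariance entry.
marginal_cov = sigma_xi**2 * np.eye(n_age) + sigma_alpha_s3**2 * np.ones((n_age, n_age))
marginal_mean = alpha_age + alpha_s2_parent

# Monte Carlo draws from the two-stage (non-marginalised) model.
n_draws = 200_000
alpha_s3 = rng.normal(alpha_s2_parent, sigma_alpha_s3, size=(n_draws, 1))
xi = rng.normal(alpha_age + alpha_s3, sigma_xi)  # shape (n_draws, n_age)

cov_error = np.abs(np.cov(xi, rowvar=False) - marginal_cov).max()
mean_error = np.abs(xi.mean(axis=0) - marginal_mean).max()
print(cov_error, mean_error)  # both should be small (Monte Carlo noise only)
```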

This is great! We’re looking forward to it.

initialized to specific values

I guess you can use the init_to_value strategy for this.

Hey, I’m really struggling to get this to converge. All the iterations are divergent.

I reduced the model to a 2-tier hierarchy (so just s1 and s2) and simulated data by fixing the hyperparameters to exact values, but I still cannot get the model to converge. I have an init_to_value model running now, but it is excruciatingly slow (code here if anyone is interested).

Nevertheless, I’ll write this up and contribute it on the numpyro GitHub. Would you like to work through the issues here on the forum, or shall I move everything over to a PR on GitHub so we can work on it there?

You might try SVI inference instead. It might give you more reliable results.

Thanks for offering the example. I’m looking forward to reading it.