Hi there,
I’m translating a model I wrote in BUGS into numpyro and I need some help making it correct and efficient. The model produces mortality estimates for different age groups over space and time. The model uses the following effects to form a logit-transformed death rate:
- global intercept and slope (`alpha0`, `beta0`)
- spatial intercepts and slopes (`alpha_s`, `beta_s`). These come from a three-tier nested administrative hierarchy. The `s1` terms represent the largest geography and are made up of `s2` units, which are in turn made up of `s3` units (think state → county → tract). The `s2` terms are centred on the `s1` terms, and the `s3` terms are centred on the `s2` terms.
- age terms (`alpha_age`, `beta_age`). A random walk over age groups produces a curve with higher mortality at infant and older ages.
- age-space interaction (`xi`), allowing different spatial units to have different age curves
- space/age random walks (`nu` and `gamma`), allowing different spatial units/age groups to have non-linear temporal patterns
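The nesting can be stored as one parent index per child unit, which appears to be what `grid.lookup.s2[, 2]` and `grid.lookup[, 2]` hold in the BUGS code. A minimal numpy sketch of deriving such tables from per-observation ids, assuming 0-based contiguous `hier1_id`, `hier2_id`, `hier3_id` arrays and a hypothetical helper `parent_lookup` (not part of any library):

```python
import numpy as np


def parent_lookup(child_id, parent_id, n_child):
    """Hypothetical helper: map each child unit to its parent unit.

    child_id, parent_id: per-observation integer ids (0-based, contiguous).
    Assumes strict nesting, i.e. every child unit has exactly one parent.
    """
    lookup = np.full(n_child, -1, dtype=int)
    lookup[child_id] = parent_id  # with repeated children, the last write wins
    # Strict-nesting check: every observation must agree with the table.
    if not np.array_equal(lookup[child_id], parent_id):
        raise ValueError("a child unit maps to more than one parent")
    if (lookup < 0).any():
        raise ValueError("some child units never appear in the data")
    return lookup


# Toy data: 2 states -> 3 counties -> 5 tracts, one row per observation.
hier1_id = np.array([0, 0, 0, 1, 1, 1, 0])
hier2_id = np.array([0, 0, 1, 2, 2, 2, 1])
hier3_id = np.array([0, 1, 2, 3, 4, 4, 2])

s2_to_s1 = parent_lookup(hier2_id, hier1_id, n_child=3)
s3_to_s2 = parent_lookup(hier3_id, hier2_id, n_child=5)
print(s2_to_s1)  # [0 0 1]
print(s3_to_s2)  # [0 0 1 2 2]
```

Building the tables once outside the model keeps the hierarchy explicit and lets the model index unit-level effects rather than per-observation ones.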
We map the death rate to death counts using a beta-binomial likelihood with known populations (i.e. the number of trials).
More detail on the model can be found in the paper. Inference was performed in nimble, which compiles the model to C++ and runs MCMC. The BUGS model is below:
model{
  # global terms
  alpha0 ~ dnorm(0, 0.00001)
  beta0 ~ dnorm(0, 0.00001)
  # spatial hierarchy
  for(s1 in 1:N_s1){
    alpha_s1[s1] ~ dnorm(0, sd = sigma_alpha_s1)
    beta_s1[s1] ~ dnorm(0, sd = sigma_beta_s1)
  }
  sigma_alpha_s1 ~ dunif(0,2)
  sigma_beta_s1 ~ dunif(0,2)
  for(s2 in 1:N_s2){
    alpha_s2[s2] ~ dnorm(alpha_s1[grid.lookup.s2[s2, 2]], sd = sigma_alpha_s2) # centred on s1 terms
    beta_s2[s2] ~ dnorm(beta_s1[grid.lookup.s2[s2, 2]], sd = sigma_beta_s2)
  }
  sigma_alpha_s2 ~ dunif(0,2)
  sigma_beta_s2 ~ dunif(0,2)
  for(s in 1:N_space){ # s = s3, N_space = N_s3
    alpha_s3[s] ~ dnorm(alpha_s2[grid.lookup[s, 2]], sd = sigma_alpha_s3) # centred on s2 terms
    beta_s3[s] ~ dnorm(beta_s2[grid.lookup[s, 2]], sd = sigma_beta_s3)
  }
  sigma_alpha_s3 ~ dunif(0,2)
  sigma_beta_s3 ~ dunif(0,2)
  # age
  alpha_age[1] <- alpha0 # initialise first terms for RW
  beta_age[1] <- beta0
  for(a in 2:N_age_groups){
    alpha_age[a] ~ dnorm(alpha_age[a-1], sd = sigma_alpha_age) # RW based on previous age group
    beta_age[a] ~ dnorm(beta_age[a-1], sd = sigma_beta_age)
  }
  sigma_alpha_age ~ dunif(0,2)
  sigma_beta_age ~ dunif(0,2)
  # age-space interactions
  for(a in 1:N_age_groups) {
    for(s in 1:N_space) {
      xi[a, s] ~ dnorm(alpha_age[a] + alpha_s3[s], sd = sigma_xi)
    }
  }
  sigma_xi ~ dunif(0,2)
  # space-time random walk
  for(s in 1:N_space){
    nu[s, 1] <- 0
    for(t in 2:N_year) {
      nu[s, t] ~ dnorm(nu[s, t-1] + beta_s3[s], sd = sigma_nu)
    }
  }
  sigma_nu ~ dunif(0,2)
  # age-time random walk
  for(a in 1:N_age_groups){
    gamma[a, 1] <- 0
    for(t in 2:N_year) {
      gamma[a, t] ~ dnorm(gamma[a, t-1] + beta_age[a], sd = sigma_gamma)
    }
  }
  sigma_gamma ~ dunif(0,2)
  for(a in 1:N_age_groups) {
    for(s in 1:N_space) {
      for(t in 1:N_year) {
        latent_rate[a, s, t] <- xi[a, s] + nu[s, t] + gamma[a, t]
      }
    }
  }
  for (i in 1:N) {
    # y is number of deaths in that cell
    # mu is predicted death rate in that cell
    # n is the number of people in that cell
    y[i] ~ dbetabin(alpha[i], beta[i], n[i])
    alpha[i] <- mu[i] * theta
    beta[i] <- (1 - mu[i]) * theta
    logit(mu[i]) <- latent_rate[age[i], space[i], yr[i]]
  }
  theta ~ dexp(0.1)
}
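In the BUGS code the space-time and age-time walks carry drifts (`beta_s3[s]` and `beta_age[a]`) and start at 0. Unrolling the recursion gives `nu[s, t] = (t - 1) * beta_s3[s] + (a drift-free walk that starts at 0)`, so a translation can write the slopes as an explicit linear trend in a 0-based year index plus a separate drift-free walk. A numpy check of that identity with made-up innovations:

```python
import numpy as np

rng = np.random.default_rng(0)
n_year, drift, sigma = 18, 0.03, 0.1           # illustrative values
eps = rng.normal(0.0, sigma, size=n_year - 1)  # innovations for t = 2..N_year

# BUGS-style recursion (0-based here): nu[0] = 0, nu[t] = nu[t-1] + drift + eps
nu_rec = np.zeros(n_year)
for t in range(1, n_year):
    nu_rec[t] = nu_rec[t - 1] + drift + eps[t - 1]

# Decomposition: linear trend in a 0-based year index + drift-free walk from 0
year_idx = np.arange(n_year)
walk = np.concatenate([[0.0], np.cumsum(eps)])
nu_dec = drift * year_idx + walk

print(np.allclose(nu_rec, nu_dec))  # True
```

Summing the two walks, the combined slope for cell `(a, s)` is `beta_age[a] + beta_s3[s]`, with `beta0` entering through `beta_age[1]`.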
My attempt at converting this into numpyro is below
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.scipy.special import expit
def model(
    age_id,
    hier1_id,
    hier2_id,
    hier3_id,
    year_id,
    population,
    deaths=None
):
    # all *_id arguments are per-observation, 0-based, contiguous integer ids
    # global terms
    alpha0 = numpyro.sample("alpha0", dist.Normal(0.0, 100.0))
    beta0 = numpyro.sample("beta0", dist.Normal(0.0, 100.0))
    # spatial hierarchy: standard-normal offsets z, row 0 = intercepts, row 1 = slopes
    N_s1 = len(np.unique(hier1_id))
    sigma_alpha_s1 = numpyro.sample("sigma_alpha_s1", dist.Uniform(0.0, 2.0))
    sigma_beta_s1 = numpyro.sample("sigma_beta_s1", dist.Uniform(0.0, 2.0))
    with numpyro.plate("plate_s1", N_s1):
        z_s1 = numpyro.sample("z_s1", dist.Normal(0, 1).expand([2, N_s1]))
    N_s2 = len(np.unique(hier2_id))
    sigma_alpha_s2 = numpyro.sample("sigma_alpha_s2", dist.Uniform(0.0, 2.0))
    sigma_beta_s2 = numpyro.sample("sigma_beta_s2", dist.Uniform(0.0, 2.0))
    with numpyro.plate("plate_s2", N_s2):
        z_s2 = numpyro.sample("z_s2", dist.Normal(0, 1).expand([2, N_s2]))
    N_s3 = len(np.unique(hier3_id))
    sigma_alpha_s3 = numpyro.sample("sigma_alpha_s3", dist.Uniform(0.0, 2.0))
    sigma_beta_s3 = numpyro.sample("sigma_beta_s3", dist.Uniform(0.0, 2.0))
    with numpyro.plate("plate_s3", N_s3):
        z_s3 = numpyro.sample("z_s3", dist.Normal(0, 1).expand([2, N_s3]))
    # per-observation spatial intercepts and slopes: sum of the scaled offsets at each level
    alpha_s = (z_s1[0, hier1_id] * sigma_alpha_s1
               + z_s2[0, hier2_id] * sigma_alpha_s2
               + z_s3[0, hier3_id] * sigma_alpha_s3)
    beta_s = (z_s1[1, hier1_id] * sigma_beta_s1
              + z_s2[1, hier2_id] * sigma_beta_s2
              + z_s3[1, hier3_id] * sigma_beta_s3)
    # age
    N_age = len(np.unique(age_id))
    sigma_alpha_age = numpyro.sample("sigma_alpha_age", dist.Uniform(0.0, 2.0))
    sigma_beta_age = numpyro.sample("sigma_beta_age", dist.Uniform(0.0, 2.0))
    alpha_age = numpyro.sample(
        "alpha_age", dist.GaussianRandomWalk(scale=sigma_alpha_age, num_steps=N_age)
    )
    beta_age = numpyro.sample(
        "beta_age", dist.GaussianRandomWalk(scale=sigma_beta_age, num_steps=N_age)
    )
    # age-space interaction, shape (N_age, N_s3)
    sigma_xi = numpyro.sample("sigma_xi", dist.Uniform(0.0, 2.0))
    xi = numpyro.sample("xi", dist.Normal(0, sigma_xi).expand([N_age, N_s3]))
    # space-time random walk, shape (N_s3, N_t)
    N_t = len(np.unique(year_id))
    sigma_nu = numpyro.sample("sigma_nu", dist.Uniform(0.0, 2.0))
    nu = numpyro.sample(
        "nu",
        dist.GaussianRandomWalk(
            scale=sigma_nu,
            num_steps=N_t
        ).expand([N_s3])
    )
    # age-time random walk, shape (N_age, N_t)
    sigma_gamma = numpyro.sample("sigma_gamma", dist.Uniform(0.0, 2.0))
    gamma = numpyro.sample(
        "gamma",
        dist.GaussianRandomWalk(
            scale=sigma_gamma,
            num_steps=N_t
        ).expand([N_age])
    )
    # beta-binomial likelihood
    # alpha_s and beta_s are already per observation; xi is indexed by (age, s3 unit)
    latent_rate = (alpha0 + alpha_age[age_id] + alpha_s
                   + (beta0 + beta_age[age_id] + beta_s) * year_id
                   + xi[age_id, hier3_id] + nu[hier3_id, year_id] + gamma[age_id, year_id])
    mu = numpyro.deterministic("mu", expit(latent_rate))
    theta = numpyro.sample("theta", dist.Exponential(0.1))
    numpyro.sample("deaths", dist.BetaBinomial(mu * theta, (1 - mu) * theta, population), obs=deaths)
The model compiles and runs, suggesting I’ve got the shapes correct at least. Here are my questions:
- Inference in nimble was done using MCMC samplers for which posterior geometry is not as important as it is for NUTS. When converting to numpyro, I made an effort to use a non-centred reparametrisation for the normal effects. However, I’m unsure whether the spatial effects are working as I intended, i.e. that the `s3` effects are centred on the `s2` effects, which are centred on the `s1` effects.
- Is `.expand([N_age, N_s3])` the most efficient way to write the age-space interaction?
- Is expanding a `GaussianRandomWalk` with another dimension the correct way to write a set of random walks? For a pymc example, see the `month_president_effect` in this post: they would write `nu = pm.GaussianRandomWalk("nu", sigma=sigma_nu, dims=("N_t", "N_s3"))`.
- In the BUGS model, I set the first element of each random walk to 0 so that the effect is identifiable. The first element of these `GaussianRandomWalk` terms is non-zero, so how is the effect identifiable?
More importantly, although the model runs, all the iterations are divergent, so something is clearly wrong with the model! Finally, the full-size model is big (age × space × time: 19 × 6791 × 18), so any advice on making it sample as efficiently as possible is welcome (there’s plenty of low-hanging fruit).
Big (but simple) model – a lot of help needed. I appreciate all of it. Cheers,
Theo