Hi,
I am trying to fit a hierarchical model of insect emergence using NUTS. The model has two layers: the first computes the phenological (emergence-timing) patterns as differences between normal CDF values; the second combines those values with abundance and overwintering parameters to compute the emergence means, which are then used as the rate of a Poisson likelihood.
I have tested each of the two layers in a separate model, and in each case NUTS samples well and recovers the true parameters. However, when I combine the two pieces of code into one model, sampling essentially stops: my ESS values are 2, my R-hat values are over 3 million, and the model can no longer reliably recover the parameters.
Here is my code. If anyone has pointers on what could cause layers that work individually to fail this catastrophically when combined in a hierarchical model, I would greatly appreciate it!
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import numpy as np
import arviz as az
from numpyro.infer import init_to_sample
import pandas as pd
import matplotlib.pyplot as plt
import json
# Load the simulated datasets and the parameters used to generate them.
# NOTE: JAX defaults to float32, which loses precision in Normal-CDF tails. If NUTS
# still stalls, consider calling numpyro.enable_x64() here, before any JAX array is created.
rng_key = jax.random.PRNGKey(0)
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can run
# arbitrary code - only load .npz files you generated yourself.
data = np.load("sim_data_alt_careful.npz", allow_pickle=True)
all_sim_datasets = [np.asarray(ds, dtype=float) for ds in data["all_sim_datasets"]]
all_sum_qo = data["all_sum_qo"]
all_sum_qn = data["all_sum_qn"]
Y = int(data["Y"])
T = int(data["T"])
C = int(data["C"])  # presumably the number of castes; not used in the visible code - TODO confirm
# Per-observation indices into the model's (year, week, caste) rate array.
year_vec = jnp.asarray(data["year_vec"])
time_vec = jnp.asarray(data["time_vec"])
caste_vec = jnp.asarray(data["caste_vec"])
# Explicit encoding: the default text encoding is platform-dependent.
with open("true_params_alt_careful.json", "r", encoding="utf-8") as f:
    true_params = json.load(f)
# Define model
def _phenology_betas(mu, sigma, T):
    """Return weekly emergence proportions from Normal phenology curves.

    Week ``t`` receives the Normal(mu, sigma) mass on ``(t - 1, t]``. Week 0
    instead receives the whole lower tail ``(-inf, 0]`` and week ``T - 1`` the
    whole upper tail ``(T - 2, inf)``, so each curve sums to one over the weeks.

    Args:
        mu: Phenology means, shape ``(Y, K)``.
        sigma: Phenology standard deviations, shape ``(K,)``.
        T: Number of weeks.

    Returns:
        Array of shape ``(Y, T, K)``.
    """
    from jax.scipy.special import ndtr

    weeks = jnp.arange(T)[None, :, None]
    mu_b = mu[:, None, :]
    sigma_b = sigma[None, None, :]
    z_hi = (weeks - mu_b) / sigma_b          # standardized upper edge of week t
    z_lo = (weeks - 1.0 - mu_b) / sigma_b    # standardized lower edge of week t
    # Phi(b) - Phi(a) cancels catastrophically once both terms round to 1.0
    # (around z ~ 5.5 in float32). For bins above the mean, use the equivalent
    # Phi(-a) - Phi(-b), whose terms are small and therefore accurate.
    above_mean = (z_lo + z_hi) > 0.0
    betas = jnp.where(
        above_mean,
        ndtr(-z_lo) - ndtr(-z_hi),
        ndtr(z_hi) - ndtr(z_lo),
    )
    # Week 0 absorbs the lower tail: P(X <= 0).
    betas = betas.at[:, 0, :].set(ndtr((0.0 - mu) / sigma))
    # Last week absorbs the upper tail: P(X > T - 2), computed as Phi(-z) rather
    # than 1 - Phi(z) to avoid cancellation.
    betas = betas.at[:, T - 1, :].set(ndtr((mu - (T - 2.0)) / sigma))
    return betas


def model(
    sim_counts,
    year_vec,
    time_vec,
    caste_vec,
    Y,
    T,
    sum_qo,
    sum_qn,
    *,
    n_qo0=500.0,
):
    """Two-layer emergence model: Normal phenology curves feeding a Poisson likelihood.

    Layer 1 turns per-year, per-caste Normal(mu, sigma) timing curves into weekly
    emergence proportions ``betas`` of shape ``(Y, T, 4)``, castes ordered
    (qo, w, m, qn).

    Layer 2 scales those proportions into expected counts:

    * qo in year 0 is scaled by ``n_qo0``.
    * qo in later years is scaled by ``ep * sum_qn`` of the previous year.
    * w, m and qn are scaled by ``rho * sum_qo`` of the same year.

    Each observation is Poisson with the rate at ``(year_vec, time_vec, caste_vec)``.

    Args:
        sim_counts: Observed counts, one per index triple.
        year_vec, time_vec, caste_vec: Integer indices into the ``(Y, T, 4)`` rate array.
        Y: Number of years.
        T: Number of weeks.
        sum_qo: Per-year totals scaling w/m/qn emergence, length ``Y``.
        sum_qn: Per-year totals; entries ``0..Y-2`` scale the following year's qo emergence.
        n_qo0: Abundance scaling year-0 qo emergence.

    NOTE: Rates in far-tail weeks can still underflow in float32. If chains
    stall, consider ``numpyro.enable_x64()`` and a less dispersed init strategy
    than prior draws (e.g. ``init_to_median``).
    """
    # Priors (site order unchanged)
    mu_qo0 = numpyro.sample("mu_qo0", dist.Normal(7.0, 2.0).expand([Y]))
    mu_w0 = numpyro.sample("mu_w0", dist.Normal(20.0, 4.0).expand([Y]))
    mu_m0 = numpyro.sample("mu_m0", dist.Normal(25.0, 3.0).expand([Y]))
    mu_qn0 = numpyro.sample("mu_qn0", dist.Normal(25.0, 3.0).expand([Y]))
    sigma = numpyro.sample("sigma", dist.Uniform(1.0, 2.0).expand([4]))
    rho_w = numpyro.sample("rho_w", dist.TruncatedNormal(20.0, 5.0, low=0.0).expand([Y]))
    rho_m = numpyro.sample("rho_m", dist.TruncatedNormal(2.0, 0.5, low=0.0).expand([Y]))
    rho_qn = numpyro.sample("rho_qn", dist.TruncatedNormal(2.0, 0.5, low=0.0).expand([Y]))
    ep = numpyro.sample("ep", dist.Uniform(0.49, 0.55).expand([Y - 1]))

    # Layer 1: phenology -> weekly proportions, shape (Y, T, 4)
    mu = jnp.stack([mu_qo0, mu_w0, mu_m0, mu_qn0], axis=1)  # (Y, 4)
    betas = _phenology_betas(mu, sigma, T)

    # Layer 2: abundance-scaled emergence means, each (Y, T)
    lam_qo0 = n_qo0 * betas[0, :, 0]  # year 0
    lam_qo = sum_qn[:-1, None] * ep[:, None] * betas[1:, :, 0]  # years 1..Y-1
    lam_qo_full = jnp.concatenate([lam_qo0[None, :], lam_qo], axis=0)
    lam_w = sum_qo[:, None] * rho_w[:, None] * betas[:, :, 1]
    lam_m = sum_qo[:, None] * rho_m[:, None] * betas[:, :, 2]
    lam_qn = sum_qo[:, None] * rho_qn[:, None] * betas[:, :, 3]

    E_mean_stack = jnp.stack([lam_qo_full, lam_w, lam_m, lam_qn], axis=2)
    E_mean = E_mean_stack[year_vec, time_vec, caste_vec]

    # Additive floor keeps the Poisson rate positive without killing the gradient
    # (jnp.maximum has zero gradient wherever the floor is active).
    eps = 1e-9
    numpyro.sample("y_pred", dist.Poisson(E_mean + eps), obs=sim_counts)
def run_numpyro_model(
    sim_counts,
    year_vec,
    time_vec,
    caste_vec,
    Y,
    T,
    sum_qo,
    sum_qn,
    rng_key,
    *,
    num_warmup=10000,
    num_samples=20000,
    num_chains=2,
    target_accept_prob=0.95,
    init_strategy=None,
):
    """Fit ``model`` to one dataset with NUTS and return the fitted ``MCMC`` object.

    The data arguments are forwarded unchanged to ``model``. The keyword-only
    sampler settings default to the values this script originally hard-coded.

    Args:
        rng_key: JAX PRNG key for the run.
        num_warmup: Warmup (adaptation) iterations per chain.
        num_samples: Post-warmup draws per chain.
        num_chains: Number of chains.
        target_accept_prob: NUTS target acceptance probability.
        init_strategy: NumPyro init strategy; ``None`` means ``init_to_sample()``
            (initial values drawn from the prior).

    Returns:
        The ``numpyro.infer.MCMC`` instance after ``run``.
    """
    if init_strategy is None:
        init_strategy = init_to_sample()
    nuts = NUTS(model, target_accept_prob=target_accept_prob, init_strategy=init_strategy)
    mcmc = MCMC(
        nuts,
        num_warmup=num_warmup,
        num_samples=num_samples,
        progress_bar=False,
        num_chains=num_chains,
    )
    mcmc.run(
        rng_key,
        sim_counts=sim_counts,
        year_vec=year_vec,
        time_vec=time_vec,
        caste_vec=caste_vec,
        Y=Y,
        T=T,
        sum_qo=sum_qo,
        sum_qn=sum_qn,
    )
    return mcmc
def _flatten_true_params(params):
    """Return ``params`` re-keyed by the row labels that ``az.summary`` produces.

    ArviZ labels the elements of an array-valued variable as ``name[i]`` (or
    ``name[i, j]`` for higher rank). A plain ``.map(true_params)`` on those labels
    therefore matches nothing, and every coverage check silently becomes False.
    Scalar entries keep their original name.
    """
    flat = {}
    for name, value in params.items():
        arr = np.asarray(value)
        if arr.ndim == 0:
            flat[name] = arr.item()
            continue
        for idx in np.ndindex(arr.shape):
            label = f"{name}[{', '.join(str(j) for j in idx)}]"
            flat[label] = arr[idx].item()
    return flat


# Loop-invariant: the same true parameters apply to every simulated dataset.
true_dict = {k: np.asarray(v) for k, v in true_params.items()}
true_by_label = _flatten_true_params(true_params)

all_summaries = []
for i, sim_counts in enumerate(all_sim_datasets):
    rng_key, subkey = jax.random.split(rng_key)
    sim_counts = jnp.asarray(sim_counts)
    sum_qo_array = jnp.asarray(all_sum_qo[i, :])
    sum_qn_array = jnp.asarray(all_sum_qn[i, :])
    mcmc = run_numpyro_model(
        sim_counts=sim_counts,
        year_vec=year_vec,
        time_vec=time_vec,
        caste_vec=caste_vec,
        Y=Y,
        T=T,
        sum_qo=sum_qo_array,
        sum_qn=sum_qn_array,
        rng_key=subkey,
    )
    idata = az.from_numpyro(mcmc, constant_data=true_dict, log_likelihood=False)

    # Traceplot
    az.plot_trace(idata)
    plt.savefig(f"traceplot_frankenstein2_{i}.png", dpi=200)
    plt.close()

    # Summary: record whether each true value falls inside the 95% HDI.
    summary = az.summary(idata, hdi_prob=0.95)
    summary["dataset"] = i
    summary["parameter"] = summary.index
    summary["true_value"] = summary["parameter"].map(true_by_label)
    summary["covered"] = (
        (summary["true_value"] >= summary["hdi_2.5%"])
        & (summary["true_value"] <= summary["hdi_97.5%"])
    )
    all_summaries.append(summary)

big_summary_FT1 = pd.concat(all_summaries, ignore_index=True)
big_summary_FT1.to_csv("summary_frankenstein2.csv", index=False)