Hi all,
I’m trying to build a hierarchical logistic regression with a latent discrete variable. I’m new to numpyro, so this may be simple. I’ve trimmed my code down to try and find a minimal example.
My data:
-
Multiple sites.
-
At each site, there are a variable number of observations with covariates
xand binary outcomesy -
If a site is inactive, all outcomes at the site will be 0.
-
If a site is active, then the outcomes follow a logistic regression:
y_ij∼Bernoulli(σ(α_i + β_i*x_ij))
with αsite, βsite drawn from hierarchical priors.
-
A discrete latent variable z, which depends on site-level covariates decides whether a site is inactive (0) or active (1).
Here’s a snippet to simulate data.
import numpy as np
from scipy.special import expit as sigmoid
# True hyperparameters
mu_alpha = 0.5
sigma_alpha = 0.3
mu_beta = 2.0
sigma_beta = 0.4
gamma_0 = -1.0 # Intercept for site activity
gamma_1 = 1.5 # Coefficient for site covariate
# Simulation settings
num_sites = 100
num_obs_per_site = 20
# Generate site-level covariates (e.g., elevation, habitat quality)
site_covariates = np.random.normal(0, 1, num_sites)
# Generate site activity states based on site covariates
theta_sites = sigmoid(gamma_0 + gamma_1 * site_covariates)
z_true = np.random.binomial(1, theta_sites, num_sites)
# Generate site-level parameters
alpha_true = np.random.normal(mu_alpha, sigma_alpha, num_sites)
beta_true = np.random.normal(mu_beta, sigma_beta, num_sites)
# Generate observation-level covariates and site indices
x = np.random.normal(0, 1, num_sites * num_obs_per_site)
site_idx = np.repeat(np.arange(num_sites), num_obs_per_site)
# generate y_ij
p_active = sigmoid(alpha_true[site_idx] + beta_true[site_idx] * x)
p_final = z_true[site_idx] * p_active # set to 0 if site is inactive
y = np.random.binomial(1, p_final)
And here’s where I got with the model. I have tried a bunch of things, and am not sure how to correctly account for z, my discrete latent variable.
def hierarchical_logistic_model(x, site_idx, site_covariates, num_sites, y=None):
# Hyperpriors for site-level parameters
mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 2))
sigma_alpha = numpyro.sample("sigma_alpha", dist.HalfNormal(1))
mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 2))
sigma_beta = numpyro.sample("sigma_beta", dist.HalfNormal(1))
# Site activity state parameters
gamma_0 = numpyro.sample("gamma_0", dist.Normal(0, 2)) # Intercept
gamma_1 = numpyro.sample("gamma_1", dist.Normal(0, 2)) # Site covariate coefficient
# Site-level logistic parameters
alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha), sample_shape=(num_sites,))
beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta), sample_shape=(num_sites,))
# probability sites are active
theta_sites = numpyro.deterministic("theta_sites",
jax.nn.sigmoid(gamma_0 + gamma_1 * site_covariates))
# Site activity states (latent binary variables)
z = numpyro.sample("z", dist.Bernoulli(theta_sites), sample_shape=(num_sites,))
# Observation-level probabilities
logits_active = alpha[site_idx] + beta[site_idx] * x
p_active = numpyro.deterministic("p_active", jax.nn.sigmoid(logits_active))
# Final probability depends on site activity state
p_final = numpyro.deterministic("p_final", z[site_idx] * p_active)
# Likelihood
if y is not None:
numpyro.sample("y", dist.Bernoulli(p_final), obs=y)
# Initialize NUTS sampler
nuts_kernel = NUTS(hierarchical_logistic_model)
# Run MCMC
mcmc = MCMC(
nuts_kernel,
num_warmup=1000,
num_samples=2000,
num_chains=2,
progress_bar=True)
mcmc.run(x, site_idx, site_covariates, y)