Latent discrete variable in hierarchical logistic regression

Hi all,

I’m trying to build a hierarchical logistic regression with a latent discrete variable. I’m new to NumPyro, so this may be simple. I’ve trimmed my code down to try to find a minimal example.

My data:

  • Multiple sites.

  • At each site, there is a variable number of observations with covariates x and binary outcomes y.

  • If a site is inactive, all outcomes at the site will be 0.

  • If a site is active, then the outcomes follow a logistic regression:

    y_ij ∼ Bernoulli(σ(α_i + β_i·x_ij))

    with α_i, β_i drawn from hierarchical priors.

  • A discrete latent variable z_i, which depends on site-level covariates, decides whether site i is inactive (0) or active (1).
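Written out in full, the bullets above and the simulation code below amount to the following generative model. Here w_i denotes the site-level covariate (my notation; it is `site_covariates` in the code) and σ is the logistic sigmoid. Setting z_i = 0 makes every y_ij at site i equal to 0 with probability 1.

```latex
\begin{aligned}
z_i &\sim \operatorname{Bernoulli}\big(\sigma(\gamma_0 + \gamma_1 w_i)\big) \\
\alpha_i &\sim \mathcal{N}(\mu_\alpha, \sigma_\alpha^2), \qquad \beta_i \sim \mathcal{N}(\mu_\beta, \sigma_\beta^2) \\
y_{ij} \mid z_i, \alpha_i, \beta_i &\sim \operatorname{Bernoulli}\big(z_i \, \sigma(\alpha_i + \beta_i x_{ij})\big)
\end{aligned}
```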

Here’s a snippet to simulate data.

import numpy as np
from scipy.special import expit as sigmoid

# True hyperparameters
mu_alpha = 0.5
sigma_alpha = 0.3
mu_beta = 2.0
sigma_beta = 0.4
gamma_0 = -1.0  # Intercept for site activity
gamma_1 = 1.5   # Coefficient for site covariate
    
# Simulation settings
num_sites = 100
num_obs_per_site = 20
    
# Generate site-level covariates (e.g., elevation, habitat quality)
site_covariates = np.random.normal(0, 1, num_sites)
    
# Generate site activity states based on site covariates
theta_sites = sigmoid(gamma_0 + gamma_1 * site_covariates)
z_true = np.random.binomial(1, theta_sites, num_sites)
    
# Generate site-level parameters
alpha_true = np.random.normal(mu_alpha, sigma_alpha, num_sites)
beta_true = np.random.normal(mu_beta, sigma_beta, num_sites)
    
# Generate observation-level covariates and site indices
x = np.random.normal(0, 1, num_sites * num_obs_per_site)
site_idx = np.repeat(np.arange(num_sites), num_obs_per_site)
    
# generate y_ij
p_active = sigmoid(alpha_true[site_idx] + beta_true[site_idx] * x)
p_final = z_true[site_idx] * p_active # set to 0 if site is inactive
y = np.random.binomial(1, p_final)

And here’s where I got with the model. I have tried a bunch of things, and am not sure how to correctly account for z, my discrete latent variable.

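import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def hierarchical_logistic_model(x, site_idx, site_covariates, num_sites, y=None):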

    # Hyperpriors for site-level parameters
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 2))
    sigma_alpha = numpyro.sample("sigma_alpha", dist.HalfNormal(1))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 2))
    sigma_beta = numpyro.sample("sigma_beta", dist.HalfNormal(1))
    # Site activity state parameters
    gamma_0 = numpyro.sample("gamma_0", dist.Normal(0, 2))  # Intercept
    gamma_1 = numpyro.sample("gamma_1", dist.Normal(0, 2))  # Site covariate coefficient
    
    # Site-level logistic parameters 
    alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha), sample_shape=(num_sites,))
    beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta), sample_shape=(num_sites,))
   
    
    # probability sites are active
    theta_sites = numpyro.deterministic("theta_sites", 
                                       jax.nn.sigmoid(gamma_0 + gamma_1 * site_covariates))
    
    # Site activity states (latent binary variables)
    # theta_sites already has shape (num_sites,), so no extra sample_shape is needed
    z = numpyro.sample("z", dist.Bernoulli(theta_sites))
    
    # Observation-level probabilities
    logits_active = alpha[site_idx] + beta[site_idx] * x
    p_active = numpyro.deterministic("p_active", jax.nn.sigmoid(logits_active))
    
    # Final probability depends on site activity state
    p_final = numpyro.deterministic("p_final", z[site_idx] * p_active)
    
    # Likelihood
    if y is not None:
        numpyro.sample("y", dist.Bernoulli(p_final), obs=y)

# Initialize NUTS sampler
nuts_kernel = NUTS(hierarchical_logistic_model)

# Run MCMC
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=2,
    progress_bar=True)

# MCMC.run takes a PRNG key first, then the model arguments in order
mcmc.run(jax.random.PRNGKey(0), x, site_idx, site_covariates, num_sites, y)

Hi @louisfh,

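NUTS (or simpler versions of HMC) cannot handle discrete parameters, because these samplers need gradients of the log density with respect to every latent variable. You likely want to marginalize the discrete parameters out of the posterior distribution, and then run NUTS on the resulting posterior.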

See this example: Gaussian Mixture Model — NumPyro documentation


Thanks. I have a working implementation now using DiscreteHMCGibbs; I’m posting the code in case it helps others. Now that I’ve got past this problem, I’m on to the next one: making the hierarchical prior a Gaussian process over site covariates instead of a Gaussian.

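import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMC, MCMC, DiscreteHMCGibbs

def hierarchical_logistic_model(x, site_idx, site_covariates, n_sites, y=None):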
    # Hyperpriors for site-level parameters
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 2))
    sigma_alpha = numpyro.sample("sigma_alpha", dist.HalfNormal(1))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 2))
    sigma_beta = numpyro.sample("sigma_beta", dist.HalfNormal(1))
    
    # Global site activity state parameters (shared across all sites)
    gamma_0 = numpyro.sample("gamma_0", dist.Normal(0, 2))  # Intercept
    gamma_1 = numpyro.sample("gamma_1", dist.Normal(0, 2))  # Site covariate coefficient

    alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha).expand([n_sites]))
    beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta).expand([n_sites]))
    
    # Site-level parameters using plates
    with numpyro.plate("sites", n_sites, dim=-1):
        
        # Site activity probabilities based on site covariates
        theta_sites = numpyro.deterministic("theta_sites", 
                                           jax.nn.sigmoid(gamma_0 + gamma_1 * site_covariates))
        
        # Site activity states (latent binary variables); DiscreteHMCGibbs updates these
        z = numpyro.sample("z", dist.Bernoulli(theta_sites))

    # Observation-level likelihood - no plate needed since we're using site_idx
    logits_active = alpha[site_idx] + beta[site_idx] * x
    p_active = numpyro.deterministic("p_active", jax.nn.sigmoid(logits_active))
    p_final = numpyro.deterministic("p_final", z[site_idx] * p_active)

    numpyro.sample("y", dist.Bernoulli(p_final), obs=y)
    
    return {
        "mu_alpha": mu_alpha,
        "sigma_alpha": sigma_alpha,
        "mu_beta": mu_beta,
        "sigma_beta": sigma_beta,
        "alpha": alpha,
        "beta": beta,
        "gamma_0": gamma_0,
        "gamma_1": gamma_1,
        "theta_sites": theta_sites,
        "z": z,
        "p_active": p_active,
        "p_final": p_final
    }

def run_mcmc(x, site_idx, site_covariates, n_sites, y, num_warmup=1000, num_samples=2000, 
              num_chains=4, seed=42):
    """
    Run MCMC inference for the hierarchical logistic model.

    Returns:
    --------
    mcmc : MCMC object
        Fitted MCMC sampler
    """
    
    # Inner kernel: HMC updates the continuous parameters
    hmc_kernel = HMC(hierarchical_logistic_model)
    
    # Outer kernel: DiscreteHMCGibbs wraps HMC and updates the discrete site z
    discrete_hmc_gibbs_kernel = DiscreteHMCGibbs(hmc_kernel)
    
    # Run MCMC
    mcmc = MCMC(
        discrete_hmc_gibbs_kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=True
    )
    
    mcmc.run(jax.random.PRNGKey(seed), x, site_idx, site_covariates, n_sites, y)
    
    return mcmc

The MH updates on the discrete parameters are unlikely to be efficient, but they are worth a try.

You should be able to use tinygp to define the latent GP; see, e.g., Non-Gaussian Likelihoods — tinygp.

Thanks. If you’re saying the Metropolis-Hastings sampling of z_i will be inefficient, do you have another suggestion? I avoided marginalizing over z because I want to report posteriors of z, and didn’t know how to do that once z is marginalized out.

I misunderstood the details of DiscreteHMCGibbs. With the settings you chose, it looks like it uses Gibbs sampling on the discrete RVs (not MH, as I incorrectly assumed), which may be all that one can hope to do. Setting modified=True seems to minimize the self-transition probabilities, which could improve things. I’d be keen to see your chain statistics, if you’re willing to share.

(This recent paper contains ideas for better modifying the proposal distributions if the standard tools don’t work; they could be interesting additions to NumPyro: [2403.18054] Modifying Gibbs sampling to avoid self transitions)