Good morning!
I’m writing to ask for help with recovering discrete latent states. I have the following code, which runs great (and is heavily inspired by the CJS notebook on the NumPyro page, since it’s a very similar model).
Is there a convenient way to recover the discrete latent states, conditional on the data? For example, for an individual with the capture history y = [0, 1, 1, 0, 1], we know that the individual was alive at occasion 3 (where y[3] = 0) because it was recaptured at occasion 4 (y[4] = 1). As such, every posterior sample of z at that occasion should equal 1, i.e., all(z[:, 3] == 1). Conversely, at occasion 0 the individual could either have been alive but undetected or not yet entered, so the samples of z[:, 0] would be a mix of 0 and 1, depending on gamma[0] and p. I hope that makes sense!
Of course, I could compute these by hand. I’m just wondering if there’s a handy numpyro function. Thanks so much for your help! This package has really been a game changer for me.
Phil
from jax import random
from jax.scipy.special import expit
from numpyro.contrib.control_flow import scan
from numpyro.infer import NUTS, MCMC, Predictive
import arviz as az
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
# --- reproducibility ---
# one seed shared by the NumPy simulation and the JAX sampler key
RANDOM_SEED = 89

# --- MCMC settings ---
CHAIN_COUNT = 4
WARMUP_COUNT = 500
SAMPLE_COUNT = 1_000

# --- simulation settings ---
OCCASION_COUNT = 7            # number of capture occasions
SUPERPOPULATION_SIZE = 400    # animals that ever enter the population
APPARENT_SURVIVAL = 0.7       # per-interval survival probability
INITIAL_PI = 0.34             # probability of entering on the first occasion
RECAPTURE_RATE = 0.5          # per-occasion capture probability

# number of rows in the augmented capture history (observed + all-zero rows)
M = 1_000
def sim_js():
    """Simulate an augmented Jolly-Seber capture-recapture data set.

    Simulation code ported from Kery and Schaub (2012), Chapter 10. All
    settings are read from the module-level constants (RANDOM_SEED,
    OCCASION_COUNT, SUPERPOPULATION_SIZE, APPARENT_SURVIVAL, INITIAL_PI,
    RECAPTURE_RATE, M).

    Returns:
        A dict with three entries:

        * 'capture_history': int array of shape (M, OCCASION_COUNT); the
          histories of the animals detected at least once, followed by
          all-zero rows padding the array up to M rows.
        * 'N_t': per-occasion count of animals that have entered and are
          still alive (length OCCASION_COUNT).
        * 'B': per-occasion count of entrants (length OCCASION_COUNT,
          including zeros for occasions with no entrants).

    Raises:
        ValueError: if more than M animals were detected, so the history
            cannot be augmented to M rows.
    """
    rng = np.random.default_rng(RANDOM_SEED)
    interval_count = OCCASION_COUNT - 1

    # simulate entry into the population
    pi_rest = (1 - INITIAL_PI) / interval_count
    pi = np.concatenate([[INITIAL_PI], np.full(interval_count, pi_rest)])

    # which occasion did the animal enter in?
    entry_matrix = rng.multinomial(n=1, pvals=pi, size=SUPERPOPULATION_SIZE)
    entry_occasion = entry_matrix.nonzero()[1]

    # Each row is one-hot, so the column sums are the entrants per occasion.
    # Unlike np.unique(..., return_counts=True), this keeps a zero for any
    # occasion nobody entered in, so 'B' always lines up with 'N_t'.
    entrant_count = entry_matrix.sum(axis=0)

    # zero if the animal has not yet entered and one after it enters
    entry_trajectory = np.maximum.accumulate(entry_matrix, axis=1)

    # flip coins for survival between occasions
    survival_draws = rng.binomial(
        1, APPARENT_SURVIVAL, (SUPERPOPULATION_SIZE, interval_count)
    )

    # add column such that survival between t and t+1 implies alive at t+1
    survival_draws = np.column_stack(
        [np.ones(SUPERPOPULATION_SIZE), survival_draws]
    )

    # ensure that the animal survives until it enters
    is_yet_to_enter = np.arange(OCCASION_COUNT) <= entry_occasion[:, None]
    survival_draws[is_yet_to_enter] = 1

    # once the survival_draws flips to zero the remaining row stays 0
    survival_trajectory = np.cumprod(survival_draws, axis=1)

    # animal has entered AND is still alive
    state = entry_trajectory * survival_trajectory

    # binary matrix of random possible recaptures
    capture = rng.binomial(
        1, RECAPTURE_RATE, (SUPERPOPULATION_SIZE, OCCASION_COUNT)
    )

    # remove the non-detected individuals
    capture_history = state * capture
    was_captured = capture_history.sum(axis=1) > 0
    capture_history = capture_history[was_captured]

    # augment the history with nz all-zero animals
    n, _ = capture_history.shape
    nz = M - n
    if nz < 0:
        raise ValueError(
            f"M={M} is smaller than the number of detected animals ({n})"
        )
    all_zero_history = np.zeros((nz, OCCASION_COUNT))
    capture_history = np.vstack([capture_history, all_zero_history]).astype(int)

    # return a dict with relevant summary stats
    N_t = state.sum(axis=0)
    return {
        'capture_history': capture_history,
        'N_t': N_t,
        'B': entrant_count,
    }
def js_prior1(capture_history):
    """Jolly-Seber model with a three-state latent process per animal.

    Latent state codes, as used by the transition matrix below:
    0 = not yet entered, 1 = alive, 2 = dead. The latent state site is
    marked for parallel enumeration.

    Args:
        capture_history: 2-D array with one row per animal and one column
            per occasion; each column is used as the observed value of
            the "obs" site at that occasion.
    """
    n_animals, n_occasions = capture_history.shape

    # global survival and detection probabilities
    phi = numpyro.sample('phi', dist.Uniform(0, 1))
    p = numpyro.sample('p', dist.Uniform(0, 1))

    # one entry probability per occasion
    with numpyro.plate('intervals', n_occasions):
        gamma = numpyro.sample('gamma', dist.Uniform(0, 1))

    def step(carry, y_t):
        """Advance every animal by one occasion and score its observation."""
        z_prev, t = carry
        entry_prob = gamma[t]

        # rows: state at the previous occasion; columns: state at this one
        transition_matrix = jnp.array([
            [1 - entry_prob, entry_prob, 0.0],  # not yet entered
            [0.0, phi, 1 - phi],                # alive
            [0.0, 0.0, 1.0],                    # dead (absorbing)
        ])

        with numpyro.plate("animals", n_animals, dim=-1):
            # pick each animal's row according to its previous state
            state_probs = dist.util.clamp_probs(transition_matrix[z_prev])
            z_t = numpyro.sample(
                "state",
                dist.Categorical(state_probs),
                infer={"enumerate": "parallel"}
            )

            # only animals in the alive state can be captured
            capture_prob = dist.util.clamp_probs(jnp.where(z_t == 1, p, 0.0))
            numpyro.sample("obs", dist.Bernoulli(capture_prob), obs=y_t)

        return (z_t, t + 1), None

    # every animal begins in the "not yet entered" state at occasion index 0
    initial_carry = (jnp.zeros(n_animals, dtype=jnp.int32), 0)

    # scan runs over the leading axis, so put occasions first
    occasion_major_history = jnp.swapaxes(capture_history, 0, 1)
    scan(step, initial_carry, occasion_major_history)
# simulate a data set and take the augmented capture histories from it
sim_results = sim_js()
capture_histories = sim_results['capture_history']

# PRNG key for the sampler, derived from the shared seed
rng_key = random.PRNGKey(RANDOM_SEED)

# NUTS kernel built on the model function
nuts_kernel = NUTS(js_prior1)

# MCMC driver: warmup length, retained draws, and number of chains
mcmc = MCMC(
    nuts_kernel,
    num_warmup=WARMUP_COUNT,
    num_samples=SAMPLE_COUNT,
    num_chains=CHAIN_COUNT,
)

# draw the posterior samples, passing the histories to the model
mcmc.run(rng_key, capture_histories)