Good morning!
I’m writing to ask how numpyro.infer.log_likelihood computes the log likelihood for a one-step-ahead prediction with time-varying parameters. In the example below, there are fixed effects (e.g., phi) for each time period that are estimated during MCMC. I train the model on a dataset that omits the last occasion. Then I want to compute the log likelihood of the data for the last occasion (one step ahead). As a result, the full history has one more occasion than the training history.
Which phi in samples['phi'] does NumPyro use to compute the log_likelihood for the last occasion? My suspicion is that it uses samples['phi'][:, -1]. Is that correct?
Thank you!
Phil
from jax import random
from jax.scipy.special import
from numpyro.contrib.control_flow import scan
from numpyro.handlers import seed
from numpyro.infer import NUTS, MCMC, log_likelihood
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Shared seed: used for the simulation RNG, the MCMC key and the seed handler
# around log_likelihood.
RANDOM_SEED = 89

# MCMC settings.
CHAIN_COUNT = 1
WARMUP_COUNT = 500
SAMPLE_COUNT = 1000

# Simulation settings.
OCCASION_COUNT = 7
SUPERPOPULATION_SIZE = 400
# One survival probability per interval between consecutive occasions,
# i.e. OCCASION_COUNT - 1 values.
APPARENT_SURVIVAL = [0.5, 0.6, 0.7, 0.45, 0.65, 0.75]
# Probability of entering at the first occasion; the simulation splits the
# remaining mass evenly over the later occasions.
INITIAL_PI = 0.34
RECAPTURE_RATE = 0.8
# Total rows in the capture history after augmenting with all-zero rows.
M = 1000
def sim_js(
    *,
    seed=None,
    occasion_count=None,
    superpopulation_size=None,
    apparent_survival=None,
    initial_pi=None,
    recapture_rate=None,
    augmented_size=None,
):
    """Simulate a Jolly-Seber capture-recapture data set.

    Simulation code ported from Kery and Schaub (2012), Chapter 10.

    Every keyword argument defaults to ``None``, meaning "use the matching
    module-level constant" (``RANDOM_SEED``, ``OCCASION_COUNT``,
    ``SUPERPOPULATION_SIZE``, ``APPARENT_SURVIVAL``, ``INITIAL_PI``,
    ``RECAPTURE_RATE``, ``M``), so ``sim_js()`` keeps its old behavior.

    Args:
        seed: seed passed to ``np.random.default_rng``.
        occasion_count: number of sampling occasions (history columns).
        superpopulation_size: number of animals that ever enter.
        apparent_survival: survival probability for each of the
            ``occasion_count - 1`` intervals (or one value for all).
        initial_pi: probability of entering at the first occasion; the
            remaining mass is split evenly over the later occasions.
        recapture_rate: detection probability of a live, entered animal.
        augmented_size: number of rows after appending all-zero histories.

    Returns:
        dict with
        - ``'capture_history'``: int array of shape
          ``(augmented_size, occasion_count)``; detected animals first,
          followed by all-zero augmentation rows.
        - ``'N_t'``: number of entered, alive animals at each occasion.
        - ``'B'``: number of entrants at each occasion, length
          ``occasion_count`` (occasions without entrants give 0).

    Raises:
        ValueError: if more animals are detected than ``augmented_size``.
    """
    # Fall back to module constants lazily (not as default values) so that a
    # caller passing every argument does not depend on those globals.
    if seed is None:
        seed = RANDOM_SEED
    if occasion_count is None:
        occasion_count = OCCASION_COUNT
    if superpopulation_size is None:
        superpopulation_size = SUPERPOPULATION_SIZE
    if apparent_survival is None:
        apparent_survival = APPARENT_SURVIVAL
    if initial_pi is None:
        initial_pi = INITIAL_PI
    if recapture_rate is None:
        recapture_rate = RECAPTURE_RATE
    if augmented_size is None:
        augmented_size = M

    rng = np.random.default_rng(seed)
    interval_count = occasion_count - 1

    # simulate entry into the population
    pi_rest = (1 - initial_pi) / interval_count
    pi = np.concatenate([[initial_pi], np.full(interval_count, pi_rest)])
    # one-hot row per animal marking the occasion it entered in
    entry_matrix = rng.multinomial(n=1, pvals=pi, size=superpopulation_size)
    entry_occasion = entry_matrix.nonzero()[1]
    # Entrants per occasion. Summing the one-hot matrix keeps a 0 for
    # occasions nobody entered in; np.unique(..., return_counts=True) would
    # drop them and misalign B with the occasion index.
    entrant_count = entry_matrix.sum(axis=0)
    # zero if the animal has not yet entered and one after it enters
    entry_trajectory = np.maximum.accumulate(entry_matrix, axis=1)

    # flip coins for survival between occasions
    survival_draws = rng.binomial(
        1, apparent_survival, (superpopulation_size, interval_count)
    )
    # prepend a column so that column t means "survived from t - 1 to t"
    survival_draws = np.column_stack(
        [np.ones(superpopulation_size), survival_draws]
    )
    # the animal survives up to and including its entry occasion
    is_yet_to_enter = np.arange(occasion_count) <= entry_occasion[:, None]
    survival_draws[is_yet_to_enter] = 1
    # once survival flips to zero the remaining row stays 0
    survival_trajectory = np.cumprod(survival_draws, axis=1)
    # animal has entered AND is still alive
    state = entry_trajectory * survival_trajectory

    # binary matrix of random possible recaptures
    capture = rng.binomial(
        1, recapture_rate, (superpopulation_size, occasion_count)
    )
    # keep only the animals detected at least once
    capture_history = state * capture
    was_captured = capture_history.sum(axis=1) > 0
    capture_history = capture_history[was_captured]

    # augment the history with all-zero rows up to augmented_size animals
    n_detected = capture_history.shape[0]
    nz = augmented_size - n_detected
    if nz < 0:
        raise ValueError(
            f"augmented_size={augmented_size} is smaller than the number of "
            f"detected animals ({n_detected})"
        )
    all_zero_history = np.zeros((nz, occasion_count))
    capture_history = np.vstack([capture_history, all_zero_history]).astype(int)

    # return a dict with relevant summary stats
    N_t = state.sum(axis=0)
    return {
        'capture_history': capture_history,
        'N_t': N_t,
        'B': entrant_count,
    }
def js_full(capture_history):
    """NumPyro Jolly-Seber model with entry probabilities given by psi and pi.

    ``capture_history`` is a 2-D array: one row per (possibly augmented)
    animal, one column per occasion. The model scans over occasions with a
    three-state latent chain per animal (0 = not yet entered, 1 = alive,
    2 = dead). An animal in state 1 is detected with probability ``p``, and
    animals in any other state are never detected.

    Sites, in order: ``p``, ``psi``, ``pi`` (Dirichlet over occasions),
    deterministic ``gamma`` (removal probabilities), ``phi`` (one survival
    probability per interval), then per-step ``state`` (enumerated) and
    ``obs`` inside the scan.

    NOTE(review): ``phi`` and ``pi`` are sized from ``capture_history``. If
    the model is re-run (e.g. via ``log_likelihood``) on a history with more
    occasions than the one it was fitted on, the substituted posterior draws
    are shorter than these sizes. ``phi[t - 1]`` with a traced ``t`` past the
    end is presumably clamped by JAX to the last element rather than raising,
    and the static ``pi[occasion]`` lookup may raise instead. Confirm against
    the JAX version in use.
    """
    n_animals, n_occasions = capture_history.shape
    n_intervals = n_occasions - 1

    # Priors. The site order p, psi, pi, phi is kept as-is.
    p = numpyro.sample('p', dist.Uniform(0, 1))
    psi = numpyro.sample('psi', dist.Uniform(0, 1))
    pi = numpyro.sample('pi', dist.Dirichlet(jnp.ones(n_occasions)))

    # Removal probabilities: gamma_t = psi * pi_t / prod_{s < t} (1 - gamma_s).
    removal = [psi * pi[0]]
    for occasion in range(1, n_occasions):
        still_available = jnp.prod(1 - jnp.stack(removal))
        removal.append(psi * pi[occasion] / still_available)
    gamma = numpyro.deterministic('gamma', jnp.stack(removal))

    with numpyro.plate('intervals', n_intervals):
        phi = numpyro.sample('phi', dist.Uniform(0, 1))

    def step(carry, y_t):
        z_prev, t = carry
        # Row = previous state, column = next state. At t == 0, phi[t - 1] is
        # phi[-1], but every animal starts in state 0, so that row is unused.
        transition = jnp.array([
            [1 - gamma[t], gamma[t], 0.0],           # not yet entered
            [0.0, phi[t - 1], 1 - phi[t - 1]],       # alive
            [0.0, 0.0, 1.0],                         # dead (absorbing)
        ])
        with numpyro.plate("animals", n_animals, dim=-1):
            z_t = numpyro.sample(
                "state",
                dist.Categorical(dist.util.clamp_probs(transition[z_prev])),
                infer={"enumerate": "parallel"},
            )
            # only alive animals can be detected, with probability p
            detect_prob = jnp.where(z_t == 1, p, 0.0)
            numpyro.sample(
                "obs",
                dist.Bernoulli(dist.util.clamp_probs(detect_prob)),
                obs=y_t,
            )
        return (z_t, t + 1), None

    # every animal starts outside the population (state 0) at step 0
    initial_state = jnp.zeros(n_animals, dtype=jnp.int32)
    # scan over occasions: move the occasion axis to the front
    scan(step, (initial_state, 0), jnp.swapaxes(capture_history, 0, 1))
# NUTS sampler over the continuous sites; the enumerated discrete `state`
# site is handled by enumeration in the model.
nuts_kernel = NUTS(js_full)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=WARMUP_COUNT,
    num_samples=SAMPLE_COUNT,
    num_chains=CHAIN_COUNT,
)

# Simulate, then hold out the final occasion (last column) for a
# one-step-ahead check.
full_history = sim_js()['capture_history']
train_history = full_history[:, :-1]

rng_key = random.PRNGKey(RANDOM_SEED)
mcmc.run(rng_key, capture_history=train_history)
samples = mcmc.get_samples()

# Score the full history with the posterior draws.
# NOTE(review): the model is fitted on OCCASION_COUNT - 1 occasions, so each
# draw of samples['phi'] has OCCASION_COUNT - 2 entries, while js_full run on
# full_history indexes OCCASION_COUNT - 1 intervals. The held-out interval
# therefore has no posterior phi of its own. What the model reads for it
# depends on JAX's out-of-bounds indexing; confirm.
# NOTE(review): the discrete `state` site is not in `samples`, so under this
# seed handler it is presumably drawn rather than marginalized. The obs
# log-likelihoods would then be conditional on those draws; verify against
# the numpyro log_likelihood documentation.
with seed(rng_seed=RANDOM_SEED):
    result = log_likelihood(js_full, samples, capture_history=full_history)