One step ahead loglik with time varying parameters

Good morning!

I’m writing to ask how numpyro.infer.log_likelihood is computed for a one-step-ahead evaluation when the model has time-varying parameters. In the example below, there are fixed effects (e.g., phi) for each time period that are estimated during MCMC. I train the model on a dataset that omits the last occasion. Then, I want to compute the log likelihood of the data for the last occasion (one step ahead). As a result, the full history has one more period than the training history.

Which phi in samples['phi'] does NumPyro use to compute the log_likelihood for the last occasion? My suspicion is that it uses samples['phi'][:, -1]. Is that correct?

Thank you!
Phil

from jax import random
from jax.scipy.special import
from numpyro.contrib.control_flow import scan
from numpyro.handlers import seed
from numpyro.infer import NUTS, MCMC, log_likelihood
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

# seed reused for the NumPy simulation, the MCMC key, and the seed handler
RANDOM_SEED = 89

# sampler settings: chains, warmup draws, and retained draws
CHAIN_COUNT = 1
WARMUP_COUNT = 500
SAMPLE_COUNT = 1000

# simulation settings
OCCASION_COUNT = 7
SUPERPOPULATION_SIZE = 400
# one survival probability per interval between consecutive occasions
APPARENT_SURVIVAL = [0.5, 0.6, 0.7, 0.45, 0.65, 0.75]
# probability of entering on the first occasion
INITIAL_PI = 0.34
RECAPTURE_RATE = 0.8
# number of rows in the zero-augmented capture history
M = 1000

def sim_js():
    """Simulate a zero-augmented Jolly-Seber capture history.

    Simulation code ported from Kery and Schaub (2012), Chapter 10. All
    settings are read from the module-level constants (RANDOM_SEED,
    OCCASION_COUNT, SUPERPOPULATION_SIZE, APPARENT_SURVIVAL, INITIAL_PI,
    RECAPTURE_RATE, M).

    Returns:
        dict with three entries:
            'capture_history': int array of shape (M, OCCASION_COUNT); the
                rows of animals detected at least once, followed by all-zero
                rows that pad the history up to M rows.
            'N_t': per-occasion count of animals that have entered and are
                still alive.
            'B': per-occasion count of entrants, always of length
                OCCASION_COUNT (zero where nobody entered).
    """

    rng = np.random.default_rng(RANDOM_SEED)
    interval_count = OCCASION_COUNT - 1

    # simulate entry into the population: INITIAL_PI for the first occasion,
    # the remaining mass split evenly over the later occasions
    pi_rest = (1 - INITIAL_PI) / interval_count
    pi = np.concatenate([[INITIAL_PI], np.full(interval_count, pi_rest)])

    # which occasion did the animal enter in? (one-hot row per animal)
    entry_matrix = rng.multinomial(n=1, pvals=pi, size=SUPERPOPULATION_SIZE)
    entry_occasion = entry_matrix.nonzero()[1]

    # Column sums of the one-hot matrix give one count per occasion, including
    # zeros. np.unique(..., return_counts=True) would silently drop occasions
    # without entrants and leave 'B' misaligned with 'N_t'.
    entrant_count = entry_matrix.sum(axis=0)

    # zero if the animal has not yet entered and one after it enters
    entry_trajectory = np.maximum.accumulate(entry_matrix, axis=1)

    # flip coins for survival between occasions
    survival_draws = rng.binomial(
        1, APPARENT_SURVIVAL, (SUPERPOPULATION_SIZE, interval_count)
    )

    # add column such that survival between t and t+1 implies alive at t+1
    survival_draws = np.column_stack(
        [np.ones(SUPERPOPULATION_SIZE), survival_draws]
    )

    # ensure that the animal survives until it enters
    is_yet_to_enter = np.arange(OCCASION_COUNT) <= entry_occasion[:, None]
    survival_draws[is_yet_to_enter] = 1

    # once the survival_draws flips to zero the remaining row stays 0
    survival_trajectory = np.cumprod(survival_draws, axis=1)

    # animal has entered AND is still alive
    state = entry_trajectory * survival_trajectory

    # binary matrix of random possible recaptures
    capture = rng.binomial(
        1, RECAPTURE_RATE, (SUPERPOPULATION_SIZE, OCCASION_COUNT)
    )

    # remove the non-detected individuals
    capture_history = state * capture
    was_captured = capture_history.sum(axis=1) > 0
    capture_history = capture_history[was_captured]

    # augment the history with nz all-zero rows; this assumes M is at least
    # the number of detected animals
    n, _ = capture_history.shape
    nz = M - n
    all_zero_history = np.zeros((nz, OCCASION_COUNT))
    capture_history = np.vstack([capture_history, all_zero_history]).astype(int)

    # return a dict with relevant summary stats
    N_t = state.sum(axis=0)
    return {
        'capture_history': capture_history,
        'N_t': N_t,
        'B': entrant_count,
    }

def js_full(capture_history):
    """NumPyro model for a capture history with a three-state latent process.

    Each animal is in one of three states per occasion: 0 (not yet entered),
    1 (alive) or 2 (dead). The latent state is sampled at the "state" site,
    which is marked for parallel enumeration, and the capture record is
    observed at the "obs" site.

    Args:
        capture_history: 2-D array whose first axis indexes animals and whose
            second axis indexes occasions. The sizes of the 'pi' and 'phi'
            sites are derived from its shape, so they change with the number
            of occasions passed in.

    Sample sites: 'p', 'psi', 'pi', 'phi', 'state', 'obs'.
    Deterministic site: 'gamma'.
    """

    super_size, occasion_count = capture_history.shape
    interval_count = occasion_count - 1

    # constant recapture probability
    p = numpyro.sample('p', dist.Uniform(0, 1))

    # parameterize the entry probabilities in terms of pi and psi
    psi = numpyro.sample('psi', dist.Uniform(0, 1))
    pi = numpyro.sample('pi', dist.Dirichlet(jnp.ones(occasion_count)))

    # compute the removal probabilities as a function of psi and pi
    gamma = jnp.zeros(occasion_count)

    # the `vector.at[0].set(1)` notation is jax for `vector[0] = 1`
    gamma = gamma.at[0].set(psi * pi[0])
    for t in range(1, occasion_count):
        # product of (1 - gamma) over all earlier occasions
        denominator = jnp.prod(1 - gamma[:t])
        gamma = gamma.at[t].set(psi * pi[t] / denominator)
    gamma = numpyro.deterministic('gamma', gamma)

    # one survival probability per interval between consecutive occasions
    with numpyro.plate('intervals', interval_count):
        phi = numpyro.sample('phi', dist.Uniform(0, 1))

    def transition_and_capture(carry, y_current):
        # One scan step: `carry` holds the previous states and the occasion
        # index t; `y_current` is the capture column for occasion t.

        z_previous, t = carry

        # Row i holds the probabilities of moving from state i to each state.
        # At t = 0 the index t - 1 is -1; the "alive" row is not selected on
        # that step because every animal starts in state 0.
        trans_probs = jnp.array([
            [1 - gamma[t],   gamma[t],            0.0], # From not yet entered
            [         0.0, phi[t - 1], 1 - phi[t - 1]], # From alive
            [         0.0,        0.0,            1.0]  # From dead
        ])

        with numpyro.plate("animals", super_size, dim=-1):

            # pick each animal's transition row from its previous state
            mu_z_current = trans_probs[z_previous]
            z_current = numpyro.sample(
                "state",
                dist.Categorical(dist.util.clamp_probs(mu_z_current)),
                infer={"enumerate": "parallel"}
            )

            # only animals in state 1 (alive) can be captured, with prob. p
            mu_y_current = jnp.where(z_current == 1, p, 0.0)
            numpyro.sample(
                "obs",
                dist.Bernoulli(dist.util.clamp_probs(mu_y_current)),
                obs=y_current
            )

        return (z_current, t + 1), None

    # every animal starts in state 0 (not yet entered); the swapaxes makes
    # scan iterate over occasions rather than animals
    state_init = jnp.zeros(super_size, dtype=jnp.int32)
    scan(transition_and_capture, (state_init, 0),
         jnp.swapaxes(capture_history, 0, 1))

# simulate the data; the final occasion is withheld from the training set
full_history = sim_js()['capture_history']
train_history = full_history[:, :-1]

# NUTS sampler for the model, wrapped in an MCMC driver
nuts_kernel = NUTS(js_full)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=WARMUP_COUNT,
    num_samples=SAMPLE_COUNT,
    num_chains=CHAIN_COUNT,
)

# fit on the truncated history and collect the posterior draws
rng_key = random.PRNGKey(RANDOM_SEED)
mcmc.run(rng_key, train_history)
samples = mcmc.get_samples()

# evaluate the log likelihood of the full history under those draws
with seed(rng_seed=RANDOM_SEED):
    result = log_likelihood(js_full, samples, capture_history=full_history)

During inference, we marginalize out state. For the log likelihood computation, you might want to obtain the state first (conditioned on full_history) via infer_discrete (see Recover discrete latent states after enumerate, scan - #2 by fehiepsi). After adding state to the samples, I think you can use log_likelihood as you did.