Recover discrete latent states after enumerate, scan

Good morning!

I’m writing to ask for help with recovering discrete latent states. I have the following code, which runs great (and is heavily inspired by the CJS notebook on the NumPyro page, since it’s a very similar model).

Is there a handy way to recover discrete latent states, conditional on the data? For example, for an individual with a history of y = [0,1,1,0,1], we know that the individual was alive during y[3] because they were recaptured on y[4]. As such, the samples for z should be all(z[:, 3]==1). Conversely, they could have been alive during y[0], or not yet entered, so the samples for z[0] would either equal 0 or 1, depending on gamma[0] and p. I hope that makes sense!

Of course, I could compute these by hand. I’m just wondering if there’s a handy numpyro function. Thanks so much for your help! This package has really been a game changer for me.

Phil

from jax import random
from jax.scipy.special import expit
from numpyro.contrib.control_flow import scan
from numpyro.infer import NUTS, MCMC, Predictive
import arviz as az
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

# hyperparameters
RANDOM_SEED = 89  # seed shared by the data simulator and the MCMC PRNG key

# mcmc hyperparameters
CHAIN_COUNT = 4      # number of independent MCMC chains
WARMUP_COUNT = 500   # warmup (adaptation) iterations per chain
SAMPLE_COUNT = 1000  # posterior draws per chain

# simulation hyperparameters
OCCASION_COUNT = 7          # number of capture occasions
SUPERPOPULATION_SIZE = 400  # true number of animals that ever enter
APPARENT_SURVIVAL = 0.7     # per-interval survival probability (phi)
INITIAL_PI = 0.34           # probability of entering before occasion 1
RECAPTURE_RATE = 0.5        # per-occasion detection probability (p)
M = 1000                    # data-augmentation size (observed + all-zero rows)

def sim_js():
    """Simulate a Jolly-Seber capture-recapture dataset.

    Ported from Kery and Schaub (2012), Chapter 10. Reads the module-level
    hyperparameters (RANDOM_SEED, OCCASION_COUNT, SUPERPOPULATION_SIZE,
    APPARENT_SURVIVAL, INITIAL_PI, RECAPTURE_RATE, M).

    Returns a dict with:
        'capture_history': (M, OCCASION_COUNT) int 0/1 detection matrix,
            detected animals first, padded with all-zero rows up to M.
        'N_t': true population size at each occasion.
        'B': number of entrants per occasion (always length OCCASION_COUNT).
    """

    rng = np.random.default_rng(RANDOM_SEED)
    interval_count = OCCASION_COUNT - 1

    # simulate entry into the population: INITIAL_PI before occasion 1,
    # remaining mass spread evenly over later occasions
    pi_rest = (1 - INITIAL_PI) / interval_count
    pi = np.concatenate([[INITIAL_PI], np.full(interval_count, pi_rest)])

    # which occasion did the animal enter in?
    entry_matrix = rng.multinomial(n=1, pvals=pi, size=SUPERPOPULATION_SIZE)
    entry_occasion = entry_matrix.nonzero()[1]

    # bincount instead of np.unique(..., return_counts=True): np.unique
    # silently drops occasions with zero entrants, which would make B
    # shorter than OCCASION_COUNT
    entrant_count = np.bincount(entry_occasion, minlength=OCCASION_COUNT)

    # zero if the animal has not yet entered and one after it enters
    entry_trajectory = np.maximum.accumulate(entry_matrix, axis=1)

    # flip coins for survival between occasions
    survival_draws = rng.binomial(
        1, APPARENT_SURVIVAL, (SUPERPOPULATION_SIZE, interval_count)
    )

    # add column such that survival between t and t+1 implies alive at t+1
    survival_draws = np.column_stack([np.ones(SUPERPOPULATION_SIZE), survival_draws])

    # ensure that the animal survives until it enters
    is_yet_to_enter = np.arange(OCCASION_COUNT) <= entry_occasion[:, None]
    survival_draws[is_yet_to_enter] = 1

    # once a survival draw flips to zero the rest of the row stays zero
    survival_trajectory = np.cumprod(survival_draws, axis=1)

    # animal has entered AND is still alive
    state = entry_trajectory * survival_trajectory

    # binary matrix of random possible recaptures
    capture = rng.binomial(
        1, RECAPTURE_RATE, (SUPERPOPULATION_SIZE, OCCASION_COUNT)
    )

    # keep only animals that were detected at least once
    capture_history = state * capture
    was_captured = capture_history.sum(axis=1) > 0
    capture_history = capture_history[was_captured]

    # augment the observed histories with all-zero rows up to M animals
    n, _ = capture_history.shape
    nz = M - n
    all_zero_history = np.zeros((nz, OCCASION_COUNT))
    capture_history = np.vstack([capture_history, all_zero_history]).astype(int)

    # per-occasion true abundance
    N_t = state.sum(axis=0)
    return {
        'capture_history': capture_history,
        'N_t': N_t,
        'B': entrant_count,
    }

def js_prior1(capture_history):
    """Jolly-Seber model with a scan over capture occasions.

    capture_history: (animals, occasions) binary detection matrix.
    The latent state z takes values 0 = not yet entered, 1 = alive,
    2 = dead; it is marked for parallel enumeration so NUTS marginalizes
    over the discrete states.
    """

    super_size, occasion_count = capture_history.shape

    # phi: apparent survival, p: recapture probability
    phi = numpyro.sample('phi', dist.Uniform(0, 1))
    p = numpyro.sample('p', dist.Uniform(0, 1))

    # one entry ("removal") probability per occasion
    # NOTE(review): plate is named 'intervals' but has length occasion_count
    with numpyro.plate('intervals', occasion_count):
        gamma = numpyro.sample('gamma', dist.Uniform(0, 1))

    def transition_and_capture(carry, y_current):
        # carry holds the previous latent states and the occasion index t

        z_previous, t = carry

        # transition probability matrix
        trans_probs = jnp.array([
            [1 - gamma[t], gamma[t],     0.0],  # From not yet entered
            [         0.0,      phi, 1 - phi],  # From alive
            [         0.0,      0.0,     1.0]   # From dead
        ])

        with numpyro.plate("animals", super_size, dim=-1):

            # transition probabilities depend on current state
            mu_z_current = trans_probs[z_previous]
            # enumerate z in parallel so the discrete state is marginalized
            z_current = numpyro.sample(
                "state",
                dist.Categorical(dist.util.clamp_probs(mu_z_current)),
                infer={"enumerate": "parallel"}
            )

            # detection is only possible while alive (z == 1)
            mu_y_current = jnp.where(z_current == 1, p, 0.0)
            numpyro.sample(
                "obs",
                dist.Bernoulli(dist.util.clamp_probs(mu_y_current)),
                obs=y_current
            )

        return (z_current, t + 1), None

    # start everyone in the not yet entered state
    state_init = jnp.zeros(super_size, dtype=jnp.int32)
    # scan over occasions: the history is transposed to (occasions, animals)
    scan(
        transition_and_capture,
        (state_init, 0),
         jnp.swapaxes(capture_history, 0, 1)
    )

# simulate an augmented capture-history matrix
sim_results = sim_js()
capture_histories = sim_results['capture_history']

rng_key = random.PRNGKey(RANDOM_SEED)

# specify which sampler you want to use
nuts_kernel = NUTS(js_prior1)

# configure the MCMC run
mcmc = MCMC(nuts_kernel, num_warmup=WARMUP_COUNT, num_samples=SAMPLE_COUNT,
            num_chains=CHAIN_COUNT)

# run the MCMC then inspect the output
mcmc.run(rng_key, capture_histories)

You can use infer_discrete, or use predictive = Predictive(model, posterior_samples, infer_discrete=True), as in the following tutorial: Example: Bayesian Models of Annotation — NumPyro documentation

Excellent! Thank you so much for your help. I’ve augmented the code with your suggestion (see below). Unfortunately, that raises an assertion error that I don’t quite understand. Do you have any guesses for the cause? Thanks again for your help!

# posterior predictive samples
posterior_samples = mcmc.get_samples()
predictive = Predictive(js_full, posterior_samples, infer_discrete=True)
discrete_samples = predictive(rng_key, full_history)

Error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[9], line 16
     14 posterior_samples = mcmc.get_samples()
     15 predictive = Predictive(js_full, posterior_samples, infer_discrete=True)
---> 16 discrete_samples = predictive(rng_key, full_history)

File ~/miniforge3/envs/pm/lib/python3.13/site-packages/numpyro/infer/util.py:1082, in Predictive.__call__(self, rng_key, *args, **kwargs)
   1072 """
   1073 Returns dict of samples from the predictive distribution. By default, only sample sites not
   1074 contained in `posterior_samples` are returned. This can be modified by changing the
   (...)   1079 :param kwargs: model kwargs.
   1080 """
   1081 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1082     return self._call_with_params(rng_key, self.params, args, kwargs)
   1083 elif self.batch_ndims == 1:  # batch over parameters
   1084     batch_size = jnp.shape(jax.tree.flatten(self.params)[0][0])[0]

File ~/miniforge3/envs/pm/lib/python3.13/site-packages/numpyro/infer/util.py:1058, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
   1046     posterior_samples = _predictive(
   1047         guide_rng_key,
   1048         guide,
   (...)   1055         exclude_deterministic=self.exclude_deterministic,
   1056     )
   1057 model = substitute(self.model, self.params)
-> 1058 return _predictive(
   1059     rng_key,
   1060     model,
   1061     posterior_samples,
   1062     self._batch_shape,
   1063     return_sites=self.return_sites,
   1064     infer_discrete=self.infer_discrete,
   1065     parallel=self.parallel,
   1066     model_args=args,
   1067     model_kwargs=kwargs,
   1068     exclude_deterministic=self.exclude_deterministic,
   1069 )

File ~/miniforge3/envs/pm/lib/python3.13/site-packages/numpyro/infer/util.py:884, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, exclude_deterministic, model_args, model_kwargs)
    882 rng_key = rng_key.reshape(batch_shape + key_shape)
    883 chunk_size = num_samples if parallel else 1
--> 884 return soft_vmap(
    885     single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    886 )

File ~/miniforge3/envs/pm/lib/python3.13/site-packages/numpyro/util.py:453, in soft_vmap(fn, xs, batch_ndims, chunk_size)
    447     xs = jax.tree.map(
    448         lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
    449         xs,
    450     )
    451     fn = vmap(fn)
--> 453 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    454 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    455 ys = jax.tree.map(
    456     lambda y: jnp.reshape(
    457         y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
    458     )[:batch_size],
    459     ys,
    460 )

    [... skipping hidden 13 frame]

File ~/miniforge3/envs/pm/lib/python3.13/site-packages/numpyro/infer/util.py:848, in _predictive.<locals>.single_prediction(val)
    846     model_trace = prototype_trace
    847     temperature = 1
--> 848     pred_samples = _sample_posterior(
    849         config_enumerate(substituted_model),
    850         first_available_dim,
    851         temperature,
    852         rng_key,
    853         *model_args,
    854         **model_kwargs,
    855     )
    856 else:
    857     model_trace = trace(seed(substituted_model, rng_key)).get_trace(
    858         *model_args, **model_kwargs
    859     )

File ~/miniforge3/envs/pm/lib/python3.13/site-packages/numpyro/contrib/funsor/discrete.py:75, in _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs)
     73     if node["infer"].get("enumerate") == "parallel":
     74         log_measure = approx_factors[log_measures[name]]
---> 75         value = _get_support_value(log_measure, name)
     76         node["value"] = funsor.to_data(
     77             value, name_to_dim=node["infer"]["name_to_dim"]
     78         )
     80 data = {
     81     name: site["value"]
     82     for name, site in sample_tr.items()
     83     if site["type"] == "sample"
     84 }

File ~/miniforge3/envs/pm/lib/python3.13/functools.py:934, in singledispatch.<locals>.wrapper(*args, **kw)
    931 if not args:
    932     raise TypeError(f'{funcname} requires at least '
    933                     '1 positional argument')
--> 934 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniforge3/envs/pm/lib/python3.13/site-packages/numpyro/contrib/funsor/discrete.py:31, in _get_support_value_contraction(funsor_dist, name, **kwargs)
     24 @_get_support_value.register(funsor.cnf.Contraction)
     25 def _get_support_value_contraction(funsor_dist, name, **kwargs):
     26     delta_terms = [
     27         v
     28         for v in funsor_dist.terms
     29         if isinstance(v, funsor.delta.Delta) and name in v.fresh
     30     ]
---> 31     assert len(delta_terms) == 1
     32     return _get_support_value(delta_terms[0], name, **kwargs)

AssertionError:

Sorry, it turns out that infer_discrete for scan is not supported yet: numpyro/test/contrib/test_infer_discrete.py at 3aed3876837b1829119b1baa1942d883df316e89 · pyro-ppl/numpyro · GitHub

You can use infer_discrete for loop instead by creating a new “loop” model for Predictive (it’s fine to use scan model for MCMC).

    state = state_init
    for t in range(occasion_count):
        with numpyro.handlers.scope(prefix=f"{t}", hide_types=["plate"]):
            (state, _), _ = transition_and_capture((state, t), capture_history[:, t])

Excellent! I hadn’t thought to use a loop for the predictive version but that makes perfect sense. Unfortunately, I’m still running into some issues, in that the returned state values appear to be essentially random. For example, the posterior median state for all animals on all occasions is 1. See below for my implementation. Where do you think I went wrong here?

import numpyro
# request 4 XLA host devices so the 4 MCMC chains can run in parallel
numpyro.set_host_device_count(4)

from jax import random
from numpyro.contrib.control_flow import scan
from numpyro.infer import NUTS, MCMC, Predictive
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

# hyperparameters
RANDOM_SEED = 89  # seed shared by the data simulator and the MCMC PRNG key

# mcmc hyperparameters
CHAIN_COUNT = 4      # number of independent MCMC chains
WARMUP_COUNT = 500   # warmup (adaptation) iterations per chain
SAMPLE_COUNT = 1000  # posterior draws per chain

# simulation hyperparameters
OCCASION_COUNT = 7          # number of capture occasions
SUPERPOPULATION_SIZE = 400  # true number of animals that ever enter
APPARENT_SURVIVAL = 0.7     # per-interval survival probability (phi)
INITIAL_PI = 0.34           # probability of entering before occasion 1
RECAPTURE_RATE = 0.5        # per-occasion detection probability (p)
M = 1000                    # data-augmentation size (observed + all-zero rows)


def sim_js():
    """Simulation code ported from Kery and Schaub (2012), Chapter 10.

    Returns a dict with the augmented capture history ('capture_history',
    shape (M, OCCASION_COUNT)), the true per-occasion abundance ('N_t'),
    and the per-occasion entrant counts ('B').
    """

    rng = np.random.default_rng(RANDOM_SEED)
    interval_count = OCCASION_COUNT - 1

    # simulate entry into the population
    pi_rest = (1 - INITIAL_PI) / interval_count
    pi = np.concatenate([[INITIAL_PI], np.full(interval_count, pi_rest)])

    # which occasion did the animal enter in?
    entry_matrix = rng.multinomial(n=1, pvals=pi, size=SUPERPOPULATION_SIZE)
    entry_occasion = entry_matrix.nonzero()[1]
    # NOTE(review): np.unique drops occasions with zero entrants, so B may be
    # shorter than OCCASION_COUNT; np.bincount(entry_occasion,
    # minlength=OCCASION_COUNT) would be safer
    _, entrant_count = np.unique(entry_occasion, return_counts=True)

    # zero if the animal has not yet entered and one after it enters
    entry_trajectory = np.maximum.accumulate(entry_matrix, axis=1)

    # flip coins for survival between occasions
    survival_draws = rng.binomial(
        1, APPARENT_SURVIVAL, (SUPERPOPULATION_SIZE, interval_count)
    )

    # add column such that survival between t and t+1 implies alive at t+1
    survival_draws = np.column_stack([np.ones(SUPERPOPULATION_SIZE), survival_draws])

    # ensure that the animal survives until it enters
    is_yet_to_enter = np.arange(OCCASION_COUNT) <= entry_occasion[:, None]
    survival_draws[is_yet_to_enter] = 1

    # once the survival_draws flips to zero the remaining row stays 0
    survival_trajectory = np.cumprod(survival_draws, axis=1)

    # animal has entered AND is still alive
    state = entry_trajectory * survival_trajectory

    # binary matrix of random possible recaptures
    capture = rng.binomial(
        1, RECAPTURE_RATE, (SUPERPOPULATION_SIZE, OCCASION_COUNT)
    )

    # remove the non-detected individuals
    capture_history = state * capture
    was_captured = capture_history.sum(axis=1) > 0
    capture_history = capture_history[was_captured]

    # augment the history with nz animals
    n, _ = capture_history.shape
    nz = M - n
    all_zero_history = np.zeros((nz, OCCASION_COUNT))
    capture_history = np.vstack([capture_history, all_zero_history]).astype(int)

    # return a dict with relevant summary stats
    N_t = state.sum(axis=0)
    return {
        'capture_history': capture_history,
        'N_t': N_t,
        'B': entrant_count,
    }

def js_scan(capture_history):
    """Jolly-Seber model using scan over occasions (used for NUTS/MCMC).

    capture_history: (animals, occasions) binary detection matrix.
    Latent state z: 0 = not yet entered, 1 = alive, 2 = dead; sampled
    with parallel enumeration so it is marginalized out during sampling.
    """

    super_size, occasion_count = capture_history.shape

    # phi: apparent survival, p: recapture probability
    phi = numpyro.sample('phi', dist.Uniform(0, 1))
    p = numpyro.sample('p', dist.Uniform(0, 1))

    # parameterize the entry probabilities in terms of pi and psi
    psi = numpyro.sample('psi', dist.Uniform(0, 1))
    pi = numpyro.sample('pi', dist.Dirichlet(jnp.ones(occasion_count)))

    # compute the removal probabilities as a function of psi and pi
    gamma = jnp.zeros(occasion_count)

    # the `vector.at[0].set(x)` notation is jax for `vector[0] = x`
    gamma = gamma.at[0].set(psi * pi[0])
    for t in range(1, occasion_count):
        # condition on not having entered before occasion t
        denominator = jnp.prod(1 - gamma[:t])
        gamma = gamma.at[t].set(psi * pi[t] / denominator)
    gamma = numpyro.deterministic('gamma', gamma)

    def transition_and_capture(carry, y_current):
        # carry holds the previous latent states and the occasion index t

        z_previous, t = carry

        trans_probs = jnp.array([
            [1 - gamma[t], gamma[t], 0.0],  # not yet entered
            [0.0, phi, 1 - phi],            # alive
            [0.0, 0.0, 1.0]                 # dead
        ])

        with numpyro.plate("animals", super_size, dim=-1):

            # each animal's transition row depends on its previous state
            mu_z_current = trans_probs[z_previous]
            z_current = numpyro.sample(
                "state",
                dist.Categorical(dist.util.clamp_probs(mu_z_current)),
                infer={"enumerate": "parallel"}
            )

            # detection is only possible while alive (z == 1)
            mu_y_current = jnp.where(z_current == 1, p, 0.0)
            numpyro.sample(
                "obs",
                dist.Bernoulli(dist.util.clamp_probs(mu_y_current)),
                obs=y_current
            )

        return (z_current, t + 1), None

    # start everyone in the not yet entered state
    state_init = jnp.zeros(super_size, dtype=jnp.int32)
    # scan over occasions: the history is transposed to (occasions, animals)
    scan(
        transition_and_capture,
        (state_init, 0),
         jnp.swapaxes(capture_history, 0, 1)
    )

def js_loop(capture_history):
    """Same Jolly-Seber model as js_scan, but with a Python loop over
    occasions instead of scan, so it can be used with
    Predictive(..., infer_discrete=True).

    The scope handler gives each occasion's sites unique names, e.g.
    "0/state", "1/state", ... (plates are hidden so they are not renamed).
    """

    super_size, occasion_count = capture_history.shape

    # phi: apparent survival, p: recapture probability
    phi = numpyro.sample('phi', dist.Uniform(0, 1))
    p = numpyro.sample('p', dist.Uniform(0, 1))

    # parameterize the entry probabilities in terms of pi and psi
    psi = numpyro.sample('psi', dist.Uniform(0, 1))
    pi = numpyro.sample('pi', dist.Dirichlet(jnp.ones(occasion_count)))

    # compute the removal probabilities as a function of psi and pi
    gamma = jnp.zeros(occasion_count)

    # the `vector.at[0].set(x)` notation is jax for `vector[0] = x`
    gamma = gamma.at[0].set(psi * pi[0])
    for t in range(1, occasion_count):
        # condition on not having entered before occasion t
        denominator = jnp.prod(1 - gamma[:t])
        gamma = gamma.at[t].set(psi * pi[t] / denominator)
    gamma = numpyro.deterministic('gamma', gamma)

    def transition_and_capture(carry, y_current):
        # carry holds the previous latent states and the occasion index t

        z_previous, t = carry

        trans_probs = jnp.array([
            [1 - gamma[t], gamma[t], 0.0],  # not yet entered
            [0.0, phi, 1 - phi],            # alive
            [0.0, 0.0, 1.0]                 # dead
        ])

        with numpyro.plate("animals", super_size, dim=-1):

            # each animal's transition row depends on its previous state
            mu_z_current = trans_probs[z_previous]
            z_current = numpyro.sample(
                "state",
                dist.Categorical(dist.util.clamp_probs(mu_z_current)),
                infer={"enumerate": "parallel"}
            )

            # detection is only possible while alive (z == 1)
            mu_y_current = jnp.where(z_current == 1, p, 0.0)
            numpyro.sample(
                "obs",
                dist.Bernoulli(dist.util.clamp_probs(mu_y_current)),
                obs=y_current
            )

        return (z_current, t + 1), None

    # start everyone in the not yet entered state
    state_init = jnp.zeros(super_size, dtype=jnp.int32)
    state = state_init

    # unrolled loop (instead of scan) so infer_discrete can handle the model
    for t in range(occasion_count):
        with numpyro.handlers.scope(prefix=f"{t}", hide_types=["plate"]):
            (state, _), _ = transition_and_capture((state, t), capture_history[:, t])

# simulate data
sim_results = sim_js()
capture_histories = sim_results['capture_history']

# specify which sampler you want to use
nuts_kernel = NUTS(js_scan)

# configure the MCMC run
mcmc = MCMC(nuts_kernel, num_warmup=WARMUP_COUNT, num_samples=SAMPLE_COUNT,
            num_chains=CHAIN_COUNT)

# run the MCMC then inspect the output
rng_key = random.PRNGKey(RANDOM_SEED)
mcmc.run(rng_key, capture_histories)

# posterior predictive samples; the loop model (js_loop) is used here because
# infer_discrete does not support scan
posterior_samples = mcmc.get_samples()
predictive = Predictive(js_loop, posterior_samples, infer_discrete=True)
discrete_samples = predictive(rng_key, capture_histories)
# gather the per-occasion "{t}/state" sites into (samples, animals, occasions)
states = np.stack([discrete_samples[f"{t}/state"] for t in range(OCCASION_COUNT)], axis=-1)

# posterior median state per animal and occasion
state_hat = np.median(states, axis=0)
(state_hat == 1).all()

I think you can try to fix the parameters first and use predictive to see if it can recover the state. Maybe make trans_probs constant across all z_previous to make observations independent. If that works as expected, then allow for a more general trans_probs.

Okay cool! I'll try that and see if I can isolate the problem. Thanks!

For what it’s worth, ChatGPT suggested an approach with infer_discrete(), which produces very reasonable posterior predictions for the state, albeit very slowly since it loops over each sample (see below). Is there a way to somehow speed up this approach, or merge it with the version using Predictive?

Thanks a ton for your help with this! I’m still acquainting myself with numpyro, but I’m amazed at how quickly I can fit models that normally take days!

*edited code snippet to include full script

import numpyro
# request 4 XLA host devices so the 4 MCMC chains can run in parallel
numpyro.set_host_device_count(4)

from jax import random
from numpyro import handlers
from numpyro.contrib.control_flow import scan
from numpyro.contrib.funsor import infer_discrete
from numpyro.infer import NUTS, MCMC, Predictive
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist

# hyperparameters
RANDOM_SEED = 89  # seed shared by the data simulator and the MCMC PRNG key

# mcmc hyperparameters
CHAIN_COUNT = 4      # number of independent MCMC chains
WARMUP_COUNT = 500   # warmup (adaptation) iterations per chain
SAMPLE_COUNT = 1000  # posterior draws per chain

# simulation hyperparameters
OCCASION_COUNT = 7          # number of capture occasions
SUPERPOPULATION_SIZE = 400  # true number of animals that ever enter
APPARENT_SURVIVAL = 0.7     # per-interval survival probability (phi)
INITIAL_PI = 0.34           # probability of entering before occasion 1
RECAPTURE_RATE = 0.5        # per-occasion detection probability (p)
M = 1000                    # data-augmentation size (observed + all-zero rows)

def quick_diagnostic(pred_samples, header):
    """Print two contrasting true histories (die early [9], born late [16])
    next to their predicted states for a quick visual check."""
    print('\n', header)
    # padding differs so the single- and double-digit labels line up
    for idx, pad in ((9, '  '), (16, ' ')):
        print(f'True history [{idx}]:' + pad, capture_histories[idx])
        print(f'Pred history [{idx}]:' + pad, pred_samples[idx].astype(int))

def sim_js():
    """Simulation code ported from Kery and Schaub (2012), Chapter 10.

    Returns a dict with the augmented capture history ('capture_history',
    shape (M, OCCASION_COUNT)), the true per-occasion abundance ('N_t'),
    and the per-occasion entrant counts ('B').
    """

    rng = np.random.default_rng(RANDOM_SEED)
    interval_count = OCCASION_COUNT - 1

    # simulate entry into the population
    pi_rest = (1 - INITIAL_PI) / interval_count
    pi = np.concatenate([[INITIAL_PI], np.full(interval_count, pi_rest)])

    # which occasion did the animal enter in?
    entry_matrix = rng.multinomial(n=1, pvals=pi, size=SUPERPOPULATION_SIZE)
    entry_occasion = entry_matrix.nonzero()[1]
    # NOTE(review): np.unique drops occasions with zero entrants, so B may be
    # shorter than OCCASION_COUNT; np.bincount(entry_occasion,
    # minlength=OCCASION_COUNT) would be safer
    _, entrant_count = np.unique(entry_occasion, return_counts=True)

    # zero if the animal has not yet entered and one after it enters
    entry_trajectory = np.maximum.accumulate(entry_matrix, axis=1)

    # flip coins for survival between occasions
    survival_draws = rng.binomial(
        1, APPARENT_SURVIVAL, (SUPERPOPULATION_SIZE, interval_count)
    )

    # add column such that survival between t and t+1 implies alive at t+1
    survival_draws = np.column_stack([np.ones(SUPERPOPULATION_SIZE), survival_draws])

    # ensure that the animal survives until it enters
    is_yet_to_enter = np.arange(OCCASION_COUNT) <= entry_occasion[:, None]
    survival_draws[is_yet_to_enter] = 1

    # once the survival_draws flips to zero the remaining row stays 0
    survival_trajectory = np.cumprod(survival_draws, axis=1)

    # animal has entered AND is still alive
    state = entry_trajectory * survival_trajectory

    # binary matrix of random possible recaptures
    capture = rng.binomial(
        1, RECAPTURE_RATE, (SUPERPOPULATION_SIZE, OCCASION_COUNT)
    )

    # remove the non-detected individuals
    capture_history = state * capture
    was_captured = capture_history.sum(axis=1) > 0
    capture_history = capture_history[was_captured]

    # augment the history with nz animals
    n, _ = capture_history.shape
    nz = M - n
    all_zero_history = np.zeros((nz, OCCASION_COUNT))
    capture_history = np.vstack([capture_history, all_zero_history]).astype(int)

    # return a dict with relevant summary stats
    N_t = state.sum(axis=0)
    return {
        'capture_history': capture_history,
        'N_t': N_t,
        'B': entrant_count,
    }

def js_scan(capture_history):
    """Jolly-Seber model using scan over occasions (used for NUTS/MCMC).

    capture_history: (animals, occasions) binary detection matrix.
    Latent state z: 0 = not yet entered, 1 = alive, 2 = dead; sampled
    with parallel enumeration so it is marginalized out during sampling.
    """

    super_size, occasion_count = capture_history.shape

    # phi: apparent survival, p: recapture probability
    phi = numpyro.sample('phi', dist.Uniform(0, 1))
    p = numpyro.sample('p', dist.Uniform(0, 1))

    # parameterize the entry probabilities in terms of pi and psi
    psi = numpyro.sample('psi', dist.Uniform(0, 1))
    pi = numpyro.sample('pi', dist.Dirichlet(jnp.ones(occasion_count)))

    # compute the removal probabilities as a function of psi and pi
    gamma = jnp.zeros(occasion_count)

    # the `vector.at[0].set(x)` notation is jax for `vector[0] = x`
    gamma = gamma.at[0].set(psi * pi[0])
    for t in range(1, occasion_count):
        # condition on not having entered before occasion t
        denominator = jnp.prod(1 - gamma[:t])
        gamma = gamma.at[t].set(psi * pi[t] / denominator)
    gamma = numpyro.deterministic('gamma', gamma)

    def transition_and_capture(carry, y_current):
        # carry holds the previous latent states and the occasion index t

        z_previous, t = carry

        trans_probs = jnp.array([
            [1 - gamma[t], gamma[t], 0.0],  # not yet entered
            [0.0, phi, 1 - phi],            # alive
            [0.0, 0.0, 1.0]                 # dead
        ])

        with numpyro.plate("animals", super_size, dim=-1):

            # each animal's transition row depends on its previous state
            mu_z_current = trans_probs[z_previous]
            z_current = numpyro.sample(
                "state",
                dist.Categorical(dist.util.clamp_probs(mu_z_current)),
                infer={"enumerate": "parallel"}
            )

            # detection is only possible while alive (z == 1)
            mu_y_current = jnp.where(z_current == 1, p, 0.0)
            numpyro.sample(
                "obs",
                dist.Bernoulli(dist.util.clamp_probs(mu_y_current)),
                obs=y_current
            )

        return (z_current, t + 1), None

    # start everyone in the not yet entered state
    state_init = jnp.zeros(super_size, dtype=jnp.int32)
    # scan over occasions: the history is transposed to (occasions, animals)
    scan(
        transition_and_capture,
        (state_init, 0),
         jnp.swapaxes(capture_history, 0, 1)
    )

def js_loop(capture_history):
    """Same Jolly-Seber model as js_scan, but with a Python loop over
    occasions instead of scan, so it can be used with
    Predictive(..., infer_discrete=True) and with infer_discrete directly.

    The scope handler gives each occasion's sites unique names, e.g.
    "0/state", "1/state", ... (plates are hidden so they are not renamed).
    """

    super_size, occasion_count = capture_history.shape

    # phi: apparent survival, p: recapture probability
    phi = numpyro.sample('phi', dist.Uniform(0, 1))
    p = numpyro.sample('p', dist.Uniform(0, 1))

    # parameterize the entry probabilities in terms of pi and psi
    psi = numpyro.sample('psi', dist.Uniform(0, 1))
    pi = numpyro.sample('pi', dist.Dirichlet(jnp.ones(occasion_count)))

    # compute the removal probabilities as a function of psi and pi
    gamma = jnp.zeros(occasion_count)

    # the `vector.at[0].set(x)` notation is jax for `vector[0] = x`
    gamma = gamma.at[0].set(psi * pi[0])
    for t in range(1, occasion_count):
        # condition on not having entered before occasion t
        denominator = jnp.prod(1 - gamma[:t])
        gamma = gamma.at[t].set(psi * pi[t] / denominator)
    gamma = numpyro.deterministic('gamma', gamma)

    def transition_and_capture(carry, y_current):
        # carry holds the previous latent states and the occasion index t

        z_previous, t = carry

        trans_probs = jnp.array([
            [1 - gamma[t], gamma[t], 0.0],  # not yet entered
            [0.0, phi, 1 - phi],            # alive
            [0.0, 0.0, 1.0]                 # dead
        ])

        with numpyro.plate("animals", super_size, dim=-1):

            # each animal's transition row depends on its previous state
            mu_z_current = trans_probs[z_previous]
            z_current = numpyro.sample(
                "state",
                dist.Categorical(dist.util.clamp_probs(mu_z_current)),
                infer={"enumerate": "parallel"}
            )

            # detection is only possible while alive (z == 1)
            mu_y_current = jnp.where(z_current == 1, p, 0.0)
            numpyro.sample(
                "obs",
                dist.Bernoulli(dist.util.clamp_probs(mu_y_current)),
                obs=y_current
            )

        return (z_current, t + 1), None

    # start everyone in the not yet entered state
    state_init = jnp.zeros(super_size, dtype=jnp.int32)
    state = state_init

    # unrolled loop (instead of scan) so infer_discrete can handle the model
    for t in range(occasion_count):
        with numpyro.handlers.scope(prefix=f"{t}", hide_types=["plate"]):
            (state, _), _ = transition_and_capture((state, t), capture_history[:, t])

# simulate data
sim_results = sim_js()
capture_histories = sim_results['capture_history']

# specify which sampler you want to use
nuts_kernel = NUTS(js_scan)

# configure the MCMC run
mcmc = MCMC(nuts_kernel, num_warmup=WARMUP_COUNT, num_samples=SAMPLE_COUNT,
            num_chains=CHAIN_COUNT)

# run the MCMC then inspect the output
rng_key = random.PRNGKey(RANDOM_SEED)
mcmc.run(rng_key, capture_histories)

# posterior predictive samples; the loop model (js_loop) is used here because
# infer_discrete does not support scan
posterior_samples = mcmc.get_samples()
predictive = Predictive(js_loop, posterior_samples, infer_discrete=True)
discrete_samples = predictive(rng_key, capture_histories)
# gather the per-occasion "{t}/state" sites into (samples, animals, occasions)
states = np.stack([discrete_samples[f"{t}/state"] for t in range(OCCASION_COUNT)], axis=-1)
state_hat_predictive = np.median(states, axis=0)

quick_diagnostic(state_hat_predictive, 'Predictive Version')

# slow version, about 6 iter per second
# total_samples = len(posterior_samples['psi'])
total_samples = 100

all_states = []
for i in range(total_samples):
    # condition the model on a single posterior draw
    conditioned_values = {k: v[i] for k, v in posterior_samples.items()}

    # sample the discrete states for that draw with infer_discrete
    trace = handlers.trace(
        infer_discrete(
            handlers.condition(js_loop, conditioned_values),
            temperature=1,
            rng_key=random.PRNGKey(i),
        )
    ).get_trace(capture_histories)

    # collect this draw's states into (animals, occasions)
    states_i = np.stack(
        [trace[f"{t}/state"]["value"] for t in range(OCCASION_COUNT)], axis=-1
    )
    all_states.append(states_i)

# (samples, animals, occasions) -> per-cell posterior median
all_states = np.stack(all_states, axis=0)
state_hat_infer = np.median(all_states, axis=0)
quick_diagnostic(state_hat_infer, 'Slow version')

Interesting! In Predictive, we do infer discrete with temperature=1 under the hood. I’ll take a look.

1 Like