Short version of the problem
I have two implementations of a simple model. The one using numpyro’s scan gives wrong results and I would like to know what I’m doing wrong.
Detailed explanation with MWE
I’m working on a Capture-Recapture-type model where the probability of being captured on each of the T capture occasions depends on whether the animal was captured in the previous occasion.
The observed data is a binary C×T matrix (here T=5), where each row is the capture history of one animal that was ever captured during the T capture occasions.
The model uses data augmentation – effectively appending rows of zeros to Y – and then including a latent binary indicator z_i which is 1 if that animal is included in the original sample and 0 otherwise (the purpose of this is to do inference on N, the true size of the population).
I’ve implemented a version of this model using numpyro’s scan
which gives wrong results and another version using vectorized operation which gives the correct results. I would like to understand what I’m doing wrong with scan.
Here’s the scan version:
# Model using loop (scan) over capture occasions:
def model_6_2_3(nz=150, predict=False, yobs=None):
    """Capture-recapture model Mb (behavioral response) with data
    augmentation, iterating over the T capture occasions via numpyro's scan.

    Parameters
    ----------
    nz : int
        Number of all-zero rows appended to ``yobs`` (data augmentation).
    predict : bool
        Unused in this body; presumably a prior/posterior-prediction switch —
        verify against the caller.
    yobs : array of shape (C, T) or None
        Observed capture histories, one row per ever-captured animal.
        If None, placeholder sizes (100, 5) are used.
    """
    # C ever-captured animals over T occasions; M is the augmented size.
    C, T = yobs.shape if yobs is not None else (100, 5)
    M = C + nz
    # Augment data by adding rows of zeros:
    yaug = jnp.vstack([yobs, jnp.zeros(shape=(nz, T))]) if yobs is not None else None
    # Priors:
    omega = npr.sample("omega", dist.Beta(3, 3))  # Probability of inclusion, P(z=1).
    p = npr.sample("p", dist.Beta(3, 3))  # Prob. of capture if not captured last period.
    c = npr.sample("c", dist.Beta(3, 3))  # Prob. of capture if captured last period.
    plate_animals = npr.plate("Animals", M, dim=-1)
    # Determine which animal is part of the population:
    with plate_animals:
        # z is integrated out by parallel enumeration rather than sampled by NUTS.
        z = npr.sample("z", dist.Bernoulli(omega), infer={"enumerate": "parallel"})

    # The transition function will be applied to each column of yaug.
    # NOTE(review): `z` is enumerated OUTSIDE the scan body but consumed inside
    # it. numpyro allocates enumeration dimensions to the left of the plate
    # dims, and scan introduces its own leading (time) dimension for body
    # sites; the suspected cause of the wrong estimates is that z's
    # enumeration dim no longer lines up with `p_eff` inside the body, so the
    # per-occasion likelihoods are mis-broadcast across the enumerated values
    # of z — confirm against the numpyro scan/enumeration documentation.
    def transition_fn(yprev, y):
        with plate_animals:
            # Behavioral response: capture prob. is c after a capture, else p;
            # animals with z == 0 can never be captured (p_eff = 0).
            p_eff = z * ((1 - yprev) * p + yprev * c)
            npr.sample("yaug", dist.Bernoulli(p_eff), obs=y)
        # The current observation becomes the next step's "previous capture".
        return y, None

    # No animal was captured before the first occasion.
    yinit = jnp.zeros((M,))
    # swapaxes puts time first so scan iterates over occasions.
    scan(
        transition_fn,
        yinit,
        jnp.swapaxes(yaug, 0, 1)
    )
And here’s the version without scan:
# Alternative version without looping:
def model_6_2_3_v2(nz=150, predict=False, yobs=None):
    """Vectorized Mb capture-recapture model with data augmentation.

    Scores all animal/occasion cells at once under two nested plates
    instead of scanning over time; reported to recover the true
    parameters correctly.
    """
    # Sizes: C ever-captured animals, T occasions, M augmented rows.
    C, T = yobs.shape if yobs is not None else (100, 5)
    M = C + nz
    # Data augmentation: append nz all-zero capture histories.
    yaug = jnp.vstack([yobs, jnp.zeros(shape=(nz, T))]) if yobs is not None else None
    # Lagged histories: column t holds the outcome of occasion t-1 (0 at t=0).
    y_lag = jnp.hstack([jnp.zeros((M, 1)), yaug[:, :-1]])
    # Priors:
    omega = npr.sample("omega", dist.Beta(3, 3))  # inclusion probability P(z=1)
    p = npr.sample("p", dist.Beta(3, 3))          # capture prob. after a non-capture
    c = npr.sample("c", dist.Beta(3, 3))          # capture prob. after a capture
    plate_animals = npr.plate("Animals", M, dim=-2)
    plate_time = npr.plate("Time", T, dim=-1)
    # Latent inclusion indicators, integrated out by parallel enumeration:
    with plate_animals:
        z = npr.sample("z", dist.Bernoulli(omega), infer={"enumerate": "parallel"})
    # Likelihood over the full M x T grid; excluded animals (z=0) can never be caught.
    with plate_animals, plate_time:
        p_eff = z * jnp.where(y_lag > 0, c, p)
        npr.sample("yaug", dist.Bernoulli(p_eff), obs=yaug)
I simulate some data and run NUTS on each model like this:
import os
from collections import namedtuple

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as random
import numpy as np

import numpyro as npr
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.control_flow import scan, cond
from numpyro.infer import Predictive, NUTS, HMC, MCMC, SVI
# Define function to simulate data under Mb
def data_fn(N=200, T=5, p=0.3, c=0.4, seed=None):
    """Simulate capture-recapture data under model Mb (behavioral response).

    Each of N animals is captured on occasion t with probability ``p`` if it
    was NOT captured on occasion t-1 (and on the first occasion), and with
    probability ``c`` if it was captured on occasion t-1.

    Parameters
    ----------
    N : int
        True population size.
    T : int
        Number of capture occasions.
    p : float
        Capture probability after a non-capture.
    c : float
        Capture probability after a capture.
    seed : int or None
        Optional seed for a reproducible simulation (new, defaults to
        the previous unseeded behavior).

    Returns
    -------
    dict with the simulation inputs plus:
        C     : int, number of animals ever detected,
        yfull : (N, T) array of all capture histories,
        yobs  : (C, T) array of histories of ever-detected animals.
    """
    rng = np.random.default_rng(seed)
    yfull = np.zeros((N, T))
    # First capture occasion: no behavioral effect possible yet.
    yfull[:, 0] = rng.binomial(n=1, p=p, size=N)
    # Later capture occasions: probability depends on last period's outcome.
    for t in range(1, T):
        p_eff = (1 - yfull[:, t - 1]) * p + yfull[:, t - 1] * c
        yfull[:, t] = rng.binomial(n=1, p=p_eff, size=N)
    ever_detected = yfull.max(1)
    # Cast to int: np.sum on a float array returned a numpy float for a count.
    C = int(np.sum(ever_detected))
    yobs = yfull[ever_detected == 1]
    return dict(N=N, p=p, c=c, C=C, T=T, yfull=yfull, yobs=yobs)
data = data_fn(N=200, T=5, p=0.3, c=0.4)
# Utility function for inference:
def run_inference(model, rng_key, args, **kwargs):
    """Run MCMC on ``model`` and return the posterior samples.

    Parameters
    ----------
    model : callable
        A numpyro model.
    rng_key : jax PRNG key
        Key passed to ``mcmc.run``.
    args : namespace-like
        Must provide ``algo`` ("NUTS" or "HMC"), ``num_warmup``,
        ``num_samples`` and ``num_chains``.
    **kwargs
        Forwarded to the model via ``mcmc.run``.

    Raises
    ------
    ValueError
        If ``args.algo`` is not "NUTS" or "HMC" (previously this fell
        through and crashed with an unhelpful NameError on ``kernel``).
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    else:
        raise ValueError(f"Unknown sampler {args.algo!r}; expected 'NUTS' or 'HMC'.")
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # Suppress the progress bar when building the numpyro docs.
        # (`os` must be imported at the top of the file for this check.)
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(rng_key, **kwargs)
    mcmc.print_summary()
    return mcmc.get_samples()
# Bundle the inference settings in a lightweight record type.
Args = namedtuple("Args", ["algo", "num_warmup", "num_samples", "num_chains"])
args = Args(algo="NUTS", num_warmup=500, num_samples=2000, num_chains=4)
# Seed the JAX pseudo-random number generator.
key = random.PRNGKey(0)
Inference on the scan model:
# Fit the scan-based model (model_6_2_3) with a fresh subkey.
key, subkey = random.split(key)
posterior = run_inference(model_6_2_3, subkey, args, **{"nz": 150, "yobs": data["yobs"]})
# Returns wrong values (note in particular mean of c, which should be close to 0.40):
# mean std median 5.0% 95.0% n_eff r_hat
# c 0.79 0.05 0.79 0.70 0.88 4781.03 1.00
# omega 0.54 0.07 0.53 0.42 0.65 5128.32 1.00
# p 0.35 0.02 0.35 0.30 0.38 5357.59 1.00
# Number of divergences: 0
Inference on the vectorized model:
# Fit the vectorized model (model_6_2_3_v2) with a fresh subkey.
key, subkey = random.split(key)
posterior = run_inference(model_6_2_3_v2, subkey, args, **{"nz": 150, "yobs": data["yobs"]})
# Returns correct values:
# mean std median 5.0% 95.0% n_eff r_hat
# c 0.40 0.03 0.40 0.35 0.46 5190.41 1.00
# omega 0.62 0.04 0.62 0.55 0.69 5409.63 1.00
# p 0.29 0.03 0.29 0.25 0.33 4779.27 1.00
# Number of divergences: 0