Why does scan give the wrong result here?

Short version of the problem

I have two implementations of a simple model. The one using numpyro’s scan gives wrong results and I would like to know what I’m doing wrong.

Detailed explanation with MWE

I’m working on a capture–recapture model in which the probability of an animal being captured on each of the T capture occasions depends on whether it was captured on the previous occasion:

P(y_{i,t} = 1) = \begin{cases} p & \text{if } y_{i,t-1} = 0,\\ c & \text{if } y_{i,t-1} = 1, \end{cases} \qquad \text{with } y_{i,0} \equiv 0 \text{ (no capture before the first occasion).}

The observed data looks like this for T=5:

Y = \begin{bmatrix} 1 & 0 & 0 & 1 & 0\\ 0 & 0 & 1 & 0 & 0\\ 1 & 1 & 1 & 0 & 0\\ \vdots & \vdots & \vdots & \vdots & \vdots \end{bmatrix}_{C \times T}

where each row is the capture history of one animal that was captured at least once during the T capture occasions, so C is the number of distinct animals ever captured.

The model uses data augmentation: it appends nz rows of zeros to Y, giving M = C + nz rows. It then introduces a latent binary indicator z_i, which is 1 if row i corresponds to an animal that is part of the population (whether or not it was ever captured) and 0 otherwise. The purpose of this is to do inference on N, the true size of the population.

I’ve implemented one version of this model using numpyro’s scan, which gives wrong results, and another version using vectorized operations, which gives the correct results. I would like to understand what I’m doing wrong with scan.

Here’s the scan version:

# Model using loop (scan) over capture occasions:
def model_6_2_3(nz=150, predict=False, yobs=None):
    """Capture-recapture model M_b with data augmentation, written with ``scan``.

    On each occasion an animal is captured with probability ``p`` if it was
    not captured on the previous occasion and ``c`` if it was; the occasion
    before the first one counts as "not captured". ``nz`` all-zero rows are
    appended to the observed histories and a latent indicator ``z``
    (enumerated in parallel) marks which rows belong to the population.

    Args:
        nz: number of all-zero rows appended for data augmentation.
        predict: unused.
        yobs: (C, T) array of observed capture histories. When None, capture
            histories are sampled for all ``100 + nz`` rows over 5 occasions
            instead of being conditioned on data.

    NOTE(review): NumPyro versions affected by issue #1442 compute a wrong
    marginal likelihood when the global, parallel-enumerated ``z`` is used
    inside ``scan``. ``scan`` adds a time batch dimension that has no plate.
    In the reported run, the posterior mean of ``c`` was ~0.79 instead of
    ~0.40. Until you run a NumPyro release that fixes this, use the
    vectorized ``model_6_2_3_v2`` for inference.
    """
    C, T = yobs.shape if yobs is not None else (100, 5)
    M = C + nz
    # Augment data by adding rows of zeros (only possible when data is given):
    yaug = jnp.vstack([yobs, jnp.zeros(shape=(nz, T))]) if yobs is not None else None

    # Priors:
    omega = npr.sample("omega", dist.Beta(3, 3))  # Probability of inclusion, P(z=1).
    p = npr.sample("p", dist.Beta(3, 3))  # Prob. of capture if not captured last occasion.
    c = npr.sample("c", dist.Beta(3, 3))  # Prob. of capture if captured last occasion.

    plate_animals = npr.plate("Animals", M, dim=-1)

    # Determine which animal is part of the population:
    with plate_animals:
        z = npr.sample("z", dist.Bernoulli(omega), infer={"enumerate": "parallel"})

    def transition_fn(yprev, y):
        # ``y`` is one column of yaug (every row at one occasion), or None when
        # no data was given, in which case the captures are sampled.
        with plate_animals:
            p_eff = z * ((1 - yprev) * p + yprev * c)
            ynext = npr.sample("yaug", dist.Bernoulli(p_eff), obs=y)
        # Carry the observed/sampled captures forward. Cast so the carry keeps
        # yinit's dtype even when sampled Bernoulli values are integers.
        return jnp.asarray(ynext, dtype=yprev.dtype), None

    yinit = jnp.zeros((M,))
    if yaug is None:
        # No data to scan over: run T steps explicitly.
        scan(transition_fn, yinit, None, length=T)
    else:
        # The transition function is applied to each column of yaug.
        scan(transition_fn, yinit, jnp.swapaxes(yaug, 0, 1))

And here’s the version without scan:

# Alternative version without looping: every occasion is handled at once.
def model_6_2_3_v2(nz=150, predict=False, yobs=None):
    """Vectorized capture-recapture model M_b with data augmentation.

    The "captured on the previous occasion" indicator is built up front. The
    augmented capture matrix is shifted one column to the right, with a
    leading column of zeros. A single ``yaug`` site over an Animals x Time
    plate grid then scores every observation.

    Args:
        nz: number of all-zero rows appended for data augmentation.
        predict: unused.
        yobs: (C, T) array of observed capture histories. With None the
            shift below fails, because there is no augmented matrix.
    """
    n_seen, n_occasions = (100, 5) if yobs is None else yobs.shape
    n_rows = n_seen + nz

    # Pad the observed histories with nz never-captured rows.
    if yobs is None:
        y_all = None
    else:
        y_all = jnp.vstack([yobs, jnp.zeros(shape=(nz, n_occasions))])

    # Previous-occasion captures; the first occasion sees "not captured".
    y_last = jnp.hstack([jnp.zeros((n_rows, 1)), y_all[:, :-1]])

    # Priors (inclusion probability, capture prob. after a miss / after a capture):
    omega = npr.sample("omega", dist.Beta(3, 3))
    p = npr.sample("p", dist.Beta(3, 3))
    c = npr.sample("c", dist.Beta(3, 3))

    animals = npr.plate("Animals", n_rows, dim=-2)
    occasions = npr.plate("Time", n_occasions, dim=-1)

    # Population membership of each row, enumerated out during inference.
    with animals:
        in_pop = npr.sample("z", dist.Bernoulli(omega), infer={"enumerate": "parallel"})

    # Rows outside the population (in_pop == 0) can never be captured.
    with animals, occasions:
        capture_prob = in_pop * ((1 - y_last) * p + y_last * c)
        npr.sample("yaug", dist.Bernoulli(capture_prob), obs=y_all)

I simulate some data and run NUTS on each model like this:

import os
from collections import namedtuple

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as random
import numpy as np
import numpyro as npr
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.control_flow import scan, cond
from numpyro.infer import Predictive, NUTS, HMC, MCMC, SVI

# Define function to simulate data under Mb
def data_fn(N=200, T=5, p=0.3, c=0.4, rng=None):
    """Simulate capture histories under model M_b (behavioural response).

    On the first occasion every animal is captured with probability ``p``. On
    each later occasion the capture probability is ``c`` if the animal was
    captured on the previous occasion and ``p`` otherwise.

    Args:
        N: true population size (number of simulated animals).
        T: number of capture occasions.
        p: capture probability when not captured on the previous occasion
            (also used on the first occasion).
        c: capture probability when captured on the previous occasion.
        rng: optional ``numpy.random.Generator`` (or any object with a
            ``binomial(n, p, size)`` method) for reproducible draws. Defaults
            to NumPy's global random state.

    Returns:
        dict with the inputs plus:
            ``C``: number of animals captured at least once, as an int.
            ``yfull``: (N, T) array of 0/1 histories.
            ``yobs``: the C rows of ``yfull`` with at least one capture.
    """
    rng = np.random if rng is None else rng
    yfull = np.zeros((N, T))
    # First capture occasion: nothing was captured "before", so p applies.
    yfull[:, 0] = rng.binomial(n=1, p=p, size=N)
    # Later capture occasions depend on the previous occasion's outcome:
    for t in range(1, T):
        prev = yfull[:, t - 1]
        p_eff = (1 - prev) * p + prev * c
        yfull[:, t] = rng.binomial(n=1, p=p_eff, size=N)
    ever_detected = yfull.max(axis=1) == 1
    C = int(ever_detected.sum())
    yobs = yfull[ever_detected]
    return dict(N=N, p=p, c=c, C=C, T=T, yfull=yfull, yobs=yobs)

# Simulate one dataset (N=200 animals, T=5 occasions); draws come from NumPy's global RNG.
data = data_fn(N=200, T=5, p=0.3, c=0.4)

# Utility function for inference:
def run_inference(model, rng_key, args, **kwargs):
    """Run MCMC on ``model``, print a summary and return the posterior samples.

    Args:
        model: NumPyro model callable.
        rng_key: JAX PRNG key passed to ``MCMC.run``.
        args: object with attributes ``algo`` ("NUTS" or "HMC"),
            ``num_warmup``, ``num_samples`` and ``num_chains``.
        **kwargs: forwarded to the model through ``MCMC.run``.

    Returns:
        The samples returned by ``MCMC.get_samples()``.

    Raises:
        ValueError: if ``args.algo`` is neither "NUTS" nor "HMC".
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    else:
        # Previously fell through and failed later with an unbound `kernel`.
        raise ValueError(f"Unsupported algo {args.algo!r}; expected 'NUTS' or 'HMC'.")
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # Hide the progress bar when NUMPYRO_SPHINXBUILD is set in the environment.
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, **kwargs)
    mcmc.print_summary()
    return mcmc.get_samples()

# Inference settings consumed by run_inference.
Args = namedtuple("Args", ["algo", "num_warmup", "num_samples", "num_chains"])
args = Args(algo="NUTS", num_warmup=500, num_samples=2000, num_chains=4)

# Master PRNG key; split it before each inference run.
key = random.PRNGKey(0)

Inference on the scan model:

# Fresh subkey for this run; keep `key` for later splits.
key, subkey = random.split(key)
posterior = run_inference(model_6_2_3, subkey, args, nz=150, yobs=data["yobs"])

# Returns wrong values (note in particular mean of c, which should be close to 0.40):
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#          c      0.79      0.05      0.79      0.70      0.88   4781.03      1.00
#      omega      0.54      0.07      0.53      0.42      0.65   5128.32      1.00
#          p      0.35      0.02      0.35      0.30      0.38   5357.59      1.00

# Number of divergences: 0

Inference on the vectorized model:

# Fresh subkey for this run; keep `key` for later splits.
key, subkey = random.split(key)
posterior = run_inference(model_6_2_3_v2, subkey, args, nz=150, yobs=data["yobs"])

# Returns correct values:
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#          c      0.40      0.03      0.40      0.35      0.46   5190.41      1.00
#      omega      0.62      0.04      0.62      0.55      0.69   5409.63      1.00
#          p      0.29      0.03      0.29      0.25      0.33   4779.27      1.00

# Number of divergences: 0

Could you use this log_density utility to check whether you get the same log density for given values of c, omega and p? Then you could look at the trace output of log_density(...) to see what the difference is.

Thanks for the suggestion. I’m getting an error, probably because I’m not using the function log_density properly.

This is what I’m doing:

from numpyro.contrib.funsor.infer_util import log_density

# NOTE: as written this raises "Ran out of free dims during allocation for
# Animals"; the model needs to be wrapped in enum(...) with first_available_dim
# set (see the later snippets in this thread).
log_density(
    model=model_6_2_3,  # Get the same error with the vectorized model.
    model_args=(),
    model_kwargs={"yobs": data["yobs"]},
    params={"c": 0.4, "p": 0.3, "omega": 0.5})

This is the stack trace:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/var/folders/2y/fh6x5yb519xfpyzyx6yrc9g00000gn/T/ipykernel_41048/2430114414.py in <module>
      3     model_args=(),
      4     model_kwargs={"yobs": data["yobs"]},
----> 5     params={"c": 0.4, "p": 0.3, "omega": 0.5})

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    269     """
    270     result, model_trace, _ = _enum_log_density(
--> 271         model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    272     )
    273     return result.data, model_trace

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    157     model = substitute(model, data=params)
    158     with plate_to_enum_plate():
--> 159         model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    160     log_factors = []
    161     time_to_factors = defaultdict(list)  # log prob factors

~/anaconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    169         :return: `OrderedDict` containing the execution trace.
    170         """
--> 171         self(*args, **kwargs)
    172         return self.trace
    173 

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

/var/folders/2y/fh6x5yb519xfpyzyx6yrc9g00000gn/T/ipykernel_41048/2983584856.py in model_6_2_3(nz, predict, yobs)
     14 
     15     # Determine which animal is part of the population:
---> 16     with plate_animals:
     17         z = npr.sample("z", dist.Bernoulli(omega), infer={"enumerate": "parallel"})
     18 

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in __enter__(self)
    505         )
    506         indices = to_data(
--> 507             self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE
    508         )
    509         # extract the dimension allocated by to_data to match plate's current behavior

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in to_data(x, name_to_dim, dim_type)
    707     }
    708 
--> 709     msg = apply_stack(initial_msg)
    710     return msg["value"]

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in apply_stack(msg)
     45     pointer = 0
     46     for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47         handler.process_message(msg)
     48         # When a Messenger sets the "stop" field of a message,
     49         # it prevents any Messengers above it on the stack from being applied.

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in process_message(self, msg)
    521     def process_message(self, msg):
    522         if msg["type"] in ["to_funsor", "to_data"]:
--> 523             return super().process_message(msg)
    524         return OrigPlateMessenger.process_message(self, msg)
    525 

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in process_message(self, msg)
    234             self._pyro_to_funsor(msg)
    235         elif msg["type"] == "to_data":
--> 236             self._pyro_to_data(msg)
    237 
    238     @staticmethod

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in _pyro_to_data(cls, msg)
    264         name_to_dim.update(
    265             cls._get_name_to_dim(
--> 266                 batch_names, name_to_dim=name_to_dim, dim_type=dim_type
    267             )
    268         )

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in _get_name_to_dim(batch_names, name_to_dim, dim_type)
    249         # read dimensions and allocate fresh dimensions as necessary
    250         for name, dim_request in name_to_dim.items():
--> 251             name_to_dim[name] = _DIM_STACK.request(name, dim_request)[1]

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in request(self, name, dim)
    161         if not found:
    162             name, dim = self._gendim(
--> 163                 NameRequest(name, dim_type), DimRequest(dim, dim_type)
    164             )
    165 

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in _gendim(self, name_request, dim_request)
    132             or (dim_type == DimType.VISIBLE and fresh_dim <= self._first_available_dim)
    133         ):
--> 134             raise ValueError(f"Ran out of free dims during allocation for {fresh_name}")
    135 
    136         return fresh_name, fresh_dim

ValueError: Ran out of free dims during allocation for Animals

Maybe you need enum(model), as in the log_density docstring?

I get the same error with enum.

# Still fails with the same error: enum(...) alone is not enough until
# first_available_dim is also given.
log_density(
    model=enum(model_6_2_3),  # Same error as without enum
    # model=enum(config_enumerate(model_6_2_3)),  # <-- Get the same error, too.
    model_args=(),
    model_kwargs={"yobs": data["yobs"]},
    params={"c": 0.4, "p": 0.3, "omega": 0.5})

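Sorry about this – currently, we need to specify first_available_dim, i.e. the first batch dimension to the left of all plate dimensions. For model_6_2_3, whose Animals plate uses dim=-1, that is -2: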

# model_6_2_3's only plate uses dim=-1, so enumeration dims start at -2.
enum(model_6_2_3, first_available_dim=-2)

Thank you @fehiepsi, I can get log_density to run now. The returned log densities for the scan and the vectorized models are definitely different, as we expected. Looking at the trace values, nothing jumps out to me – but I also don’t quite understand some of the output there. I’m pasting the traces below in the hopes something is obviously wrong to you.

Could the problem be that the scan model (model_6_2_3) uses a plate that is created outside the scan’s transition function?

Here are the traces using log_density:

# model_6_2_3's Animals plate uses dim=-1, so enumeration starts at dim -2.
ld_scan, trace_scan = log_density(
    model=enum(model_6_2_3, first_available_dim=-2), 
    # Alternative: model=enum(config_enumerate(model_6_2_3), first_available_dim=-2),
    model_args=(),
    model_kwargs={"yobs": data["yobs"]},
    params={"c": 0.4, "p": 0.3, "omega": 0.5})

which produces trace_scan

OrderedDict(
    [('omega',
    {'type': 'sample',
    'name': 'omega',
    'fn': <numpyro.distributions.continuous.Beta at 0x7fbb299457d0>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': 0.5,
    'scale': None,
    'is_observed': False,
    'intermediates': [],
    'cond_indep_stack': [],
    'infer': {'dim_to_name': OrderedDict(), 'name_to_dim': {}}}),
    ('p',
    {'type': 'sample',
    'name': 'p',
    'fn': <numpyro.distributions.continuous.Beta at 0x7fbb01ac1c50>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': 0.3,
    'scale': None,
    'is_observed': False,
    'intermediates': [],
    'cond_indep_stack': [],
    'infer': {'dim_to_name': OrderedDict(), 'name_to_dim': {}}}),
    ('c',
    {'type': 'sample',
    'name': 'c',
    'fn': <numpyro.distributions.continuous.Beta at 0x7fbaea1b4bd0>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': 0.4,
    'scale': None,
    'is_observed': False,
    'intermediates': [],
    'cond_indep_stack': [],
    'infer': {'dim_to_name': OrderedDict(), 'name_to_dim': {}}}),
    ('z',
    {'type': 'sample',
    'name': 'z',
    'fn': <numpyro.distributions.distribution.ExpandedDistribution at 0x7fbb2993bc90>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': DeviceArray([[0],
                [1]], dtype=int32),
    'scale': None,
    'is_observed': False,
    'intermediates': [],
    'cond_indep_stack': [CondIndepStackFrame(name='Animals', dim=-1, size=312)],
    'infer': {'enumerate': 'parallel',
    'dim_to_name': OrderedDict([(-2, 'z'), (-1, 'Animals')]),
    'name_to_dim': {'z': -2, 'Animals': -1}},
    'done': True}),
    ('_PREV_yaug',
    {'type': 'sample',
    'name': '_PREV_yaug',
    'fn': <numpyro.distributions.discrete.BernoulliProbs at 0x7fbaea147c10>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': DeviceArray([0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0.,
                0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0.,
                1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
                0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0.,
                0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 1.,
                1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.,
                1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
                0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0.,
                0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.,
                1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 1.,
                0., 0., 0., 1., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],            dtype=float32),
    'scale': None,
    'is_observed': True,
    'intermediates': [],
    'cond_indep_stack': [CondIndepStackFrame(name='Animals', dim=-1, size=312)],
    'infer': {'_scan_current_index': 0,
    'dim_to_name': OrderedDict([(-2, 'z'), (-1, 'Animals')]),
    'name_to_dim': {'z': -2, 'Animals': -1}}}),
    ('yaug',
    {'args': (),
    'fn': <numpyro.distributions.discrete.BernoulliProbs at 0x7fbb1a690e90>,
    'intermediates': [],
    'value': DeviceArray([[[0., 0., 0., ..., 0., 0., 0.]],
    
                [[1., 0., 0., ..., 0., 0., 0.]],
    
                [[1., 1., 1., ..., 0., 0., 0.]],
    
                [[1., 0., 1., ..., 0., 0., 0.]]], dtype=float32),
    '_control_flow_done': True,
    'type': 'sample',
    'name': 'yaug',
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'scale': None,
    'is_observed': True,
    'cond_indep_stack': [CondIndepStackFrame(name='Animals', dim=-1, size=312)],
    'infer': {'_scan_current_index': Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,
    'dim_to_name': OrderedDict([(-3, '_time_yaug'),
                    (-2, 'z'),
                    (-1, 'Animals')]),
    'name_to_dim': {'_time_yaug': -3, 'z': -2, 'Animals': -1}}})])

And for the vectorized model,

# model_6_2_3_v2 uses Animals at dim=-2 and Time at dim=-1, so enumeration starts at dim -3.
ld_vec, trace_vec = log_density(
    model=enum(model_6_2_3_v2, first_available_dim=-3), 
    # Alternative: model=enum(config_enumerate(model_6_2_3_v2), first_available_dim=-3),
    model_args=(),
    model_kwargs={"yobs": data["yobs"]},
    params={"c": 0.4, "p": 0.3, "omega": 0.5})

which produces trace_vec

OrderedDict(
    [('omega',
    {'type': 'sample',
    'name': 'omega',
    'fn': <numpyro.distributions.continuous.Beta at 0x7fbb298b5890>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': 0.5,
    'scale': None,
    'is_observed': False,
    'intermediates': [],
    'cond_indep_stack': [],
    'infer': {'dim_to_name': OrderedDict(), 'name_to_dim': {}}}),
    ('p',
    {'type': 'sample',
    'name': 'p',
    'fn': <numpyro.distributions.continuous.Beta at 0x7fbb1a59b3d0>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': 0.3,
    'scale': None,
    'is_observed': False,
    'intermediates': [],
    'cond_indep_stack': [],
    'infer': {'dim_to_name': OrderedDict(), 'name_to_dim': {}}}),
    ('c',
    {'type': 'sample',
    'name': 'c',
    'fn': <numpyro.distributions.continuous.Beta at 0x7fbb2993b290>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': 0.4,
    'scale': None,
    'is_observed': False,
    'intermediates': [],
    'cond_indep_stack': [],
    'infer': {'dim_to_name': OrderedDict(), 'name_to_dim': {}}}),
    ('z',
    {'type': 'sample',
    'name': 'z',
    'fn': <numpyro.distributions.distribution.ExpandedDistribution at 0x7fbb39409510>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': DeviceArray([[[0]],
    
                [[1]]], dtype=int32),
    'scale': None,
    'is_observed': False,
    'intermediates': [],
    'cond_indep_stack': [CondIndepStackFrame(name='Animals', dim=-2, size=312)],
    'infer': {'enumerate': 'parallel',
    'dim_to_name': OrderedDict([(-3, 'z'), (-2, 'Animals')]),
    'name_to_dim': {'z': -3, 'Animals': -2}},
    'done': True}),
    ('yaug',
    {'type': 'sample',
    'name': 'yaug',
    'fn': <numpyro.distributions.discrete.BernoulliProbs at 0x7fbb298ec810>,
    'args': (),
    'kwargs': {'rng_key': None, 'sample_shape': ()},
    'value': DeviceArray([[0., 0., 1., 1., 1.],
                [0., 0., 0., 1., 0.],
                [0., 0., 0., 1., 1.],
                ...,
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.]], dtype=float32),
    'scale': None,
    'is_observed': True,
    'intermediates': [],
    'cond_indep_stack': [CondIndepStackFrame(name='Time', dim=-1, size=5),
    CondIndepStackFrame(name='Animals', dim=-2, size=312)],
    'infer': {'dim_to_name': OrderedDict([(-3, 'z'),
                    (-2, 'Animals'),
                    (-1, 'Time')]),
    'name_to_dim': {'z': -3, 'Animals': -2, 'Time': -1}}})])

Thanks @Maturin! It turns out that this is a bug – could you report it on GitHub? Enumeration requires all batch dimensions to be declared using plate. However, scan creates a batch time dimension (`_time_yaug` in the trace above) that has no corresponding plate.

Great! I just opened the following issue on GitHub: #1442.

1 Like