Calculating log_likelihood for model with scan


I have a time series model that uses scan, and I would like to do leave-one-out cross-validation with it: first by running ArviZ’s loo function, and then by refitting the model for the few observations where the Pareto k values are too high. (In the future I’d like to do LFO-CV too.)

For LOO-CV I have to compute the log-likelihood of the left-out datapoint y_i, given the model fitted with y_{-i}. Since order matters in my model, the obvious idea was to fit the model with y where y_i is set to None. But a numeric numpy array can’t contain None as an element, so this doesn’t work.

Any idea what else I can do? I suspect there is a really easy and obvious solution that I have somehow missed.

I don’t know the details of LOO-CV or LFO-CV, so I might misunderstand your question. By fitting with y_{-i}, do you mean fitting the model with a shorter time series, or with the same length but with y_i as a latent variable? For the former, you can just use scan. For the latter, you can put an improper prior on y_i, replace the original y_i with that latent variable, and then use scan.

Yeah, treating y_i as a latent variable is what I want to do. But I can’t figure out how to do it in practice with scan. Maybe my issue becomes clearer if I say what I tried:

  1. Given some model
```python
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def model(X, y=None):
    # Model where y_i depends on x_i and y_{i-1}; foo and init_y are defined elsewhere
    def transition(y_prev, data_curr):
        x_curr, y_curr = data_curr
        probs = foo(y_prev, x_curr)
        obs = numpyro.sample("y", dist.Bernoulli(probs=probs), obs=y_curr)
        return obs, obs
    _, obs = scan(transition, init_y, (X, y), length=len(X))
```

My first idea was to fit the model with y such that y[i] = None. For that, y can’t be a numpy array but has to be a Python list, as far as I can tell. But neither numpyro nor scan seems to handle y being a Python list.

  2. I then thought that I could keep y as a numpy array but set y[i] = -1, and then use control flow inside scan that essentially says y_curr = None if y_curr == -1. But it seems like regular Python control flow isn’t allowed inside scan, and

```python
from jax.lax import cond

y_curr = cond(
    y_curr != -1,
    lambda _: numpyro.deterministic('y_curr', y_curr),
    lambda _: numpyro.deterministic('y_curr', None),
    None,
)
```

also doesn’t seem to work.
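For context on why the None idea fails: NumPy can only hold None in an object-dtype array, and JAX (and therefore NumPyro’s scan) only accepts numeric and boolean dtypes. A common workaround, sketched here with -1 as an assumed sentinel, is to keep a numeric array plus a boolean mask of the missing positions:

```python
import numpy as np
import jax.numpy as jnp

y_with_none = np.array([1, None, 0])  # dtype is object: NumPy falls back to Python objects

try:
    jnp.asarray(y_with_none)
    converted = True
except (TypeError, ValueError):
    converted = False  # JAX rejects object arrays

# Numeric alternative: a sentinel value (here -1, an assumption) plus a mask.
y = np.array([1, -1, 0])
missing = y == -1
y_jax = jnp.asarray(y)  # converts fine: the dtype is an integer type
```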

I guess there ought to be a simple and embarrassing solution, but I can’t think of any.

Maybe I also don’t understand what you mean by putting an improper prior on y_i.

[Edit: Changed example model to reflect the issue outlined in the comment below]

Related: I looked at the Time Series Forecasting example and tried conditioning with the condition effect handler, as done in the example. But this does not seem to work if scan uses obs as a carry, i.e. if y_t directly depends on y_{t-1}.

If there is no alternative or easy workaround (the example says it has to be done this way in order to forecast), this seems like a big problem for autoregressive time series models.

Here is my approach (just a heuristic):

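```python
import numpyro
import numpyro.distributions as dist

# inside the model, before the scan call (y is a concrete NumPy array, -1 = missing)
y_latents = []
for i in range(int((y == -1).sum())):
    # one binary latent per missing entry; .mask(False) drops its log-density (flat prior)
    y_latents.append(numpyro.sample(f"y_null_{i}", dist.Bernoulli(0.5).mask(False)))
y = fill_null_by_latents(y, y_latents)  # helper not shown: places the latents where y == -1
# then run scan as if all y is observed
```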
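The helper `fill_null_by_latents` is not spelled out in the reply, and it is not a NumPyro function. A minimal sketch of one possible implementation, assuming missing entries are marked with -1 and that y is a concrete NumPy array (so the missing positions are known before tracing), using JAX’s functional `.at[...].set(...)` update:

```python
import numpy as np
import jax.numpy as jnp


def fill_null_by_latents(y, y_latents, null_value=-1):
    """Hypothetical helper: put the latent draws where y == null_value.

    y is a concrete NumPy array, so the missing positions are static.
    y_latents is a list or a 1-D array with one entry per missing position, in order.
    """
    y = np.asarray(y)
    missing_idx = np.flatnonzero(y == null_value)  # static integer indices
    latents = jnp.asarray(y_latents, dtype=jnp.float32).reshape(-1)
    if latents.shape[0] != missing_idx.shape[0]:
        raise ValueError("need exactly one latent value per missing entry")
    # .at[...].set returns a new array; JAX arrays cannot be modified in place
    return jnp.asarray(y, dtype=jnp.float32).at[missing_idx].set(latents)
```

For example, `fill_null_by_latents(np.array([1, -1, 0, -1]), [1.0, 0.0])` gives `[1., 1., 0., 0.]`. Accepting either a list or a vector means the same helper also works with a single vector-valued latent site.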

Your approaches do not work because scan and cond require static shapes in order to compile and run fast. I’m not sure what you want to achieve with deterministic; if you want something like cond(y_curr != -1, lambda _: y_curr, lambda _: None, None), it will not work because the outputs of the true and false branches have different shapes. You can replace them with Python control flow (a for loop and if/else).
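To make the shape argument concrete: `jax.lax.cond` traces both branches and requires their outputs to have the same structure, shape and dtype. A small sketch (the zero fill in the false branch is only for illustration):

```python
import jax
import jax.numpy as jnp


def pick_ok(y_curr):
    # both branches return an array of the same shape/dtype, so this traces
    return jax.lax.cond(y_curr != -1, lambda v: v, lambda v: jnp.zeros_like(v), y_curr)


def pick_bad(y_curr):
    # the false branch returns None (an empty pytree), which cannot match an array
    return jax.lax.cond(y_curr != -1, lambda v: v, lambda v: None, y_curr)


ok = pick_ok(jnp.asarray(3)), pick_ok(jnp.asarray(-1))

try:
    pick_bad(jnp.asarray(3))
    bad_failed = False
except (TypeError, ValueError):
    bad_failed = True
```

The failure happens even when the predicate is true, because the mismatch is detected while tracing, not while executing.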


Great! This does indeed work. Two related questions have come up for me:

  1. While implementing your solution I once accidentally fitted the model with y[i] = -1 but without your proposed change to the model. Surprisingly, this worked without warnings or errors even though this means that at some point obs = numpyro.sample("y", dist.Bernoulli(probs=probs), obs=-1) must have been called. What happens there? Does it treat this as if obs is None? A shallow look at my results seems to suggest this.

  2. An alternative solution I considered before you answered was to use the obs_mask argument of sample(...), where I denote the left-out sample by a False in the obs_mask.

```python
def model(X, y=None, obs_mask=None):
    # Model where y_i depends on x_i and y_{i-1}
    def transition(y_prev, data_curr):
        x_curr, y_curr, obs_mask_curr = data_curr
        probs = foo(y_prev, x_curr)
        obs = numpyro.sample(
            "y", dist.Bernoulli(probs=probs), obs=y_curr, obs_mask=obs_mask_curr
        )
        return obs, obs
    _, obs = scan(transition, init_y, (X, y, obs_mask), length=len(y))
```

But this doesn’t seem to work. Can you say what my mistake here is?

Thanks so much!

Update regarding the second point: I have now realized that for this to work, obs_mask_curr has to be a numpy array. (Sorry, I should have noticed this with a closer look at the documentation.)

But I keep having issues with shapes, so could you quickly confirm whether resolving them is even possible? The error arises here, in numpyro/contrib/control_flow/

```
    176             # FIXME: is this rigorous?
    177             new_carry = tree_multimap(
--> 178                 lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
    179             )
    180         return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
```

> y[i] = -1 … worked without warnings or errors

I think you can enable validation with numpyro.enable_validation() to catch this issue. It runs silently because the implementation can still produce numbers for out-of-support values: e.g., we might expect the identity function lambda x: x to work for real numbers, but that implementation also works for complex numbers.

Masking out the probability of missing sites is not equivalent to marginalization. If you want to mask those sites, you can do:

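```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

missing = y == -1                     # boolean mask of the left-out positions
y_filled = jnp.where(missing, 0, y)   # placeholder value at those positions
def transition(y_prev, data_curr):    # scan over (X, y_filled, missing)
    x_curr, y_filled_curr, missing_curr = data_curr
    probs = foo(y_prev, x_curr)
    probs = jnp.where(missing_curr, 0.5, probs)  # use uniform prior
    obs = numpyro.sample("y", dist.Bernoulli(probs), obs=y_filled_curr)
    return obs, obs
```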

but I’m not sure that this approach is valid (e.g., issues might arise when we fill y with some value and the current probs make that value unlikely). I would recommend using enumeration to marginalize out those missing sites, or using DiscreteHMCGibbs to sample them.
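A small numerical illustration of why masking is not the same as marginalizing, using a hypothetical logistic AR(1) Bernoulli chain with made-up parameters a and b: one value y_i is missing between two observed neighbours, both equal to 1. Marginalization sums over both possible values of y_i; the fill-and-mask approach above scores only the filled value 0 under a 0.5 prior.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def transition_logp(y_next, y_prev, a=-0.5, b=2.0):
    # log p(y_next | y_prev) for a hypothetical logistic AR(1) Bernoulli chain
    logit = a + b * y_prev
    return y_next * jax.nn.log_sigmoid(logit) + (1 - y_next) * jax.nn.log_sigmoid(-logit)


y_prev, y_next = 1.0, 1.0  # observed neighbours of the missing y_i

# Marginalization: sum the joint over both possible values of the missing y_i.
candidates = jnp.array([0.0, 1.0])
marginal = logsumexp(transition_logp(candidates, y_prev) + transition_logp(y_next, candidates))

# Fill-and-mask: y_i filled with 0 and scored under a uniform 0.5 "prior".
masked = jnp.log(0.5) + transition_logp(y_next, 0.0)
```

Here the filled value 0 is the less likely one given y_{i-1} = 1, so the masked score differs substantially from the marginal log-likelihood, which is the failure mode described above.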


Thanks, that’s super helpful.

So using DiscreteHMCGibbs works in the sense that it doesn’t give me errors, using obs_mask as in comment #6. But I’m a bit confused by the behavior because when leaving out one observation of N, I would expect there to be one site y_unobserved, but instead there are N-1.
The documentation of numpyro.sample says:

  • obs_mask (numpy.ndarray) – Optional boolean array mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides.

and I definitely made sure that obs_mask is True at all positions except one. Do you know what might be happening here? Another oddity is that print_summary() sometimes reports a negative n_eff at these sites.

For marginalizing via enumeration, I would require the enumeration strategy to be "sequential" instead of "parallel", since future decisions depend on past decisions. (I think that was also the reason for my shape errors before, since it tried to do parallel enumeration.) But this isn’t possible in NumPyro yet, right?

IIUC, obs_mask is useful for SVI. It is a bit tricky for me to think of a way to use it in your model.

If you use marginalization, it does not matter if future decisions depend on past decisions (basically, we evaluate all possible cases). If you don’t want to use marginalization, you can use DiscreteHMCGibbs. In principle, both ways use the same code as in my previous comment:

```python
y_latents = []
for i in range(sum(y == -1)):
    y_latents.append(numpyro.sample(f"y_null_{i}", dist.Bernoulli(0.5).mask(False)))
y = fill_null_by_latents(y, y_latents)
# then run scan as if all y is observed
```

But looking back at this issue, I don’t think scan works with global discrete latent variables yet (sorry, you’re right), so I recommend using DiscreteHMCGibbs. You can simplify the code to:

```python
y_latents = numpyro.sample(
    "y_latents", dist.Bernoulli(0.5).expand([sum(y == -1)]).to_event().mask(False)
)
y = fill_null_by_latents(y, y_latents)
```
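To illustrate the earlier point that marginalization works even when each value depends on the previous one ("we evaluate all possible cases"), here is a hand-written forward recursion, the standard hidden-Markov-model technique swapped in for illustration (it is not NumPyro’s own enumeration machinery), over a hypothetical logistic AR(1) Bernoulli chain where -1 marks a missing value. It is checked against brute-force enumeration of every filling of the missing entries.

```python
import itertools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

A, B, P0 = -0.5, 2.0, 0.5  # hypothetical chain parameters; P0 = p(y_0 = 1)


def log_trans():
    # log_T[j, k] = log p(y_t = k | y_{t-1} = j) for a logistic AR(1) Bernoulli chain
    logit = A + B * jnp.array([[0.0], [1.0]])
    return jnp.concatenate([jax.nn.log_sigmoid(-logit), jax.nn.log_sigmoid(logit)], axis=1)


def state_mask(y_t):
    # 0 for states compatible with y_t, -inf otherwise; -1 (missing) allows both states
    allowed = (y_t == -1) | (jnp.arange(2) == y_t)
    return jnp.where(allowed, 0.0, -jnp.inf)


def forward_loglik(y):
    # log p(observed entries of y), summing over every value of the missing entries
    log_T = log_trans()
    log_alpha = jnp.log(jnp.array([1 - P0, P0])) + state_mask(y[0])

    def step(log_alpha, y_t):
        # log_alpha[k] = log p(observed parts of y_0..y_t, y_t = k)
        log_alpha = logsumexp(log_alpha[:, None] + log_T, axis=0) + state_mask(y_t)
        return log_alpha, None

    log_alpha, _ = jax.lax.scan(step, log_alpha, y[1:])
    return logsumexp(log_alpha)


def brute_force_loglik(y):
    # same quantity by explicitly enumerating all 2**k fillings of the k missing entries
    y = [int(v) for v in y]
    missing = [t for t, v in enumerate(y) if v == -1]
    log_T = log_trans()
    log_p0 = jnp.log(jnp.array([1 - P0, P0]))
    terms = []
    for fill in itertools.product([0, 1], repeat=len(missing)):
        full = list(y)
        for t, v in zip(missing, fill):
            full[t] = v
        terms.append(log_p0[full[0]] + sum(log_T[a, b] for a, b in zip(full[:-1], full[1:])))
    return logsumexp(jnp.array(terms))
```

The forward recursion costs time linear in the series length, while brute-force enumeration grows as 2**k in the number k of missing values; both give the same marginal log-likelihood.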

Thanks again for your help, this seemed to work great, but now I stumbled upon a problem.
The way I implemented this was like the following:

```python
for i in np.flatnonzero(y == -1):
    y[i] = numpyro.sample(f"y_null_{i}", dist.Bernoulli(0.5).mask(False))
```

This works if I set num_chains to 1, but now that I am running my experiments on a larger scale with more chains, it no longer works: I get ValueError: setting an array element with a sequence. in the loop’s body.

Do you know how to solve this?

If y is a JAX array, the in-place assignment you use won’t work. Please use index_update instead; see the "Bayesian Imputation" example in the NumPyro documentation.
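For readers on newer JAX versions: `jax.ops.index_update` has since been removed from JAX, and the equivalent functional update is `x.at[idx].set(value)`. A minimal sketch of the difference, with a stand-in value in place of a `numpyro.sample` draw:

```python
import numpy as np
import jax.numpy as jnp

y = jnp.asarray(np.array([1.0, -1.0, 0.0]))
draw = jnp.asarray(1.0)  # stand-in for a numpyro.sample(...) result

try:
    y[1] = draw  # JAX arrays are immutable: in-place assignment is rejected
    assigned_in_place = True
except TypeError:
    assigned_in_place = False

# Functional update: each .at[i].set(...) returns a new array and leaves y unchanged.
y_new = y
for i in np.flatnonzero(np.asarray(y) == -1):
    y_new = y_new.at[i].set(draw)
```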

Thanks, this works more or less. One weird failure mode: when I run multiple chains sequentially, I get an error after the first chain has finished (AssertionError: Cannot detect any discrete latent variables in the model.).
This doesn’t happen when I run them in parallel, so for my purposes I consider the problem solved, but maybe there is a bug in DiscreteHMCGibbs?

I think it is a bug. Could you make a github issue with reproducible code? :slight_smile:

Hi, sorry for the late reply, I got busy and forgot about it.

I have now tried to make a minimal example for reproduction, but I can’t reproduce the bug anymore.
(I also can’t reproduce it in my real codebase, but the code there has changed since I reported the issue here, and I am not sure which change has caused the problem to disappear.)

If I encounter the problem again I’ll try to extract a minimal example from that.