Scan with missing data: combining handlers.condition with handlers.mask

I managed to get this working.

For the statement

level = jnp.where(
jnp.isnan(y[t]),
previous_level,
level_smoothing * y[t] + (1 - level_smoothing) * previous_level,
)

Masking the likelihood was not enough. Apparently, during autodiff, the gradients flow through both arms of the jnp.where: the dead branch receives a zero cotangent, but 0 * NaN is still NaN, so NaN gradients from that branch ruin everything and cause divergences.
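To make that concrete, here is a minimal sketch in plain JAX, outside of any model. The function name naive_level and the numbers are just for illustration; it only mirrors the jnp.where statement above.

```python
import jax
import jax.numpy as jnp


def naive_level(level_smoothing, y_t, previous_level):
    # Same structure as the statement above: the "updated" arm is computed
    # from y_t even when y_t is NaN.
    return jnp.where(
        jnp.isnan(y_t),
        previous_level,
        level_smoothing * y_t + (1 - level_smoothing) * previous_level,
    )


# Forward pass: where() selects previous_level, so the value is finite.
value = naive_level(0.3, jnp.nan, 1.0)

# Backward pass: the unselected arm gets a zero cotangent, but 0 * NaN is
# still NaN, so the gradient w.r.t. level_smoothing comes out as NaN.
grad_missing = jax.grad(naive_level)(0.3, jnp.nan, 1.0)

# With an observed value the gradient is simply y_t - previous_level.
grad_observed = jax.grad(naive_level)(0.3, 2.0, 1.0)

print(value, grad_missing, grad_observed)
```

So the forward value looks fine while the gradient is already poisoned, which would explain why masking the likelihood alone does not help.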

The solution is to replace NaN with some finite value (the value does not make a difference as long as it lies within the support of the process; I tried 1, 1e6, etc. with the code below and the output was the same). This keeps the JAX computation graph NaN-free. The logic to freeze the state updates at missing positions now works. I then used the distribution's .mask method, i.e. dist.Normal(...).mask(is_observed), to zero out the log-prob contribution at missing positions (rather than handlers.mask).

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import Array
from numpyro.contrib.control_flow import scan


def level_model(y: Array, future: int = 0) -> None:
    t_max = y.shape[0]

    _NAN_SENTINEL = 1.0  # any finite value
    nan_mask = jnp.isnan(y)
    y_clean = jnp.where(nan_mask, _NAN_SENTINEL, y)
    # Forecast steps are treated as missing as well
    nan_mask_padded = jnp.concatenate([nan_mask, jnp.ones(future, dtype=bool)])

    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    def transition_fn(carry, t):
        previous_level = carry
        is_observed = (t < t_max) & ~nan_mask_padded[t]
        level_updated = (
            level_smoothing * y_clean[t]
            + (1 - level_smoothing) * previous_level
        )
        # Freeze the state when the observation is missing
        level = jnp.where(is_observed, level_updated, previous_level)

        mu = previous_level
        pred = numpyro.sample(
            "pred",
            dist.Normal(loc=mu, scale=noise).mask(is_observed),  # note mask
        )
        return level, pred

    with numpyro.handlers.condition(data={"pred": y_clean}):  # use the y without nan
        _, preds = scan(
            transition_fn, level_init, jnp.arange(t_max + future),
        )

    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])

The idea to keep the data as safe values is taken from Observation Masked. The sigma inflation idea also works (I think; I have tested it less), but I prefer using mask.
