Scan with missing data: combining handlers.condition with handlers.mask

Hello everyone again,

I’m working on some time series models based on @juanitorduz’s posts. The long-term goal is to extend his hierarchical multivariate ETS model, but to start we can deal with the univariate case.

I’m using the usual pattern of scan + handlers.condition. This works fine when I have a complete time series, but now I have missing data. I want to impute these points and not have them contribute to the log-likelihood, something I’ve done with mask in other projects.

Here’s the model, adapted from Juan’s post but with an added where statement to impute where the data are missing.

def level_model(y: Array, future: int = 0) -> None:
    """Local-level smoothing model with a NaN check in the level update.

    Samples ``level_smoothing``, ``level_init`` and ``noise``, then scans over
    ``t_max + future`` time steps. At every step the prediction ``pred`` is
    drawn from a Normal centred on the previous level, and the ``pred`` site is
    conditioned on ``y``. When ``future > 0`` the last ``future`` predictions
    are recorded as the deterministic site ``y_forecast``.

    Args:
        y: Observed series; its first dimension is the number of time steps.
            Entries may be NaN to mark a missing observation.
        future: Number of steps to forecast beyond the end of ``y``.
    """
    t_max = y.shape[0]

    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    def transition_fn(carry, t):
        # carry is the level produced by the previous time step.
        previous_level = carry

        # NOTE(review): jnp.where evaluates both branches, so the "usual update"
        # expression below is NaN whenever y[t] is NaN even though it is not
        # selected. That can turn the gradients into NaN (see the JAX FAQ entry
        # on jnp.where) and is a likely cause of the divergences -- confirm.
        level = jnp.where(
            t < t_max,
            jnp.where(
                jnp.isnan(y[t]),
                previous_level,  # keep previous level if observation is NaN
                level_smoothing * y[t] + (1 - level_smoothing) * previous_level,  # usual update
            ),
            previous_level,  # during forecasting period
        )

        # The prediction for step t uses the level *before* seeing y[t].
        mu = previous_level
        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))

        return level, pred


    # NOTE(review): y still contains NaN here, so the values the "pred" site is
    # conditioned on include NaN at the missing steps -- confirm how the
    # likelihood treats them.
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            level_init,
            jnp.arange(t_max + future),
        )

    if future > 0:
        # Only the steps past the end of y are exposed as forecasts.
        numpyro.deterministic("y_forecast", preds[-future:])

The model works when the data are complete, but when any data are missing, sampling gives divergences at every iteration (and SVI cannot find initial params).

Things I have tried:

  1. Adding in numpyro.handlers.mask. This still gives divergences.
    # True where an observation is present, False where y is NaN.
    obs_mask = ~jnp.isnan(y)
    # extend mask to length jnp.arange(t_max + future)
    # (the forecast steps are all marked False)
    forecast_mask = jnp.zeros(future, dtype=bool)
    extended_mask = jnp.concatenate([obs_mask, forecast_mask])

    # NOTE(review): extended_mask has one entry per time step but is applied
    # outside scan, while "pred" is sampled one step at a time inside
    # transition_fn -- confirm the mask lines up with the scan dimension;
    # masking with the per-step value inside transition_fn may be what is
    # intended.
    # NOTE(review): a mask only affects the likelihood term; transition_fn
    # still evaluates level_smoothing * y[t] on the NaN entries.
    with numpyro.handlers.condition(data={"pred": y}):
        with numpyro.handlers.mask(mask=extended_mask):
            _, preds = scan(
                transition_fn,
                level_init,
                jnp.arange(t_max + future),
            )
  2. In the sample statement, masking using obs_mask. Still divergences.
# NOTE(review): obs_mask is given but no obs= is passed at this call; the data
# only arrive through handlers.condition under the name "pred" -- confirm the
# conditioned values actually reach the site that obs_mask applies to.
pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise), obs_mask=~jnp.isnan(y[t]))

I’m interested to know if anyone has found a solution to this, or if there’s another way of going about this problem.

Cheers,
Theo

Somewhat related: Multioutput Kalman filter with scan and missing observations - #2 by fehiepsi and Calculating log_likelihood for model with scan - #3 by julianstastny

Full reproducible script

import time

import jax
from jax import random, Array
import jax.numpy as jnp
import numpyro
from numpyro.contrib.control_flow import scan
from numpyro.infer import Predictive
import numpyro.distributions as dist
import numpy as np


def level_model(y: Array, future: int = 0) -> None:
    """Local-level smoothing model that skips missing (NaN) observations.

    Samples ``level_smoothing``, ``level_init`` and ``noise``, then scans over
    ``t_max + future`` time steps. At every step the prediction ``pred`` is a
    Normal centred on the previous level. Steps whose observation is present
    update the level and contribute to the likelihood; steps whose observation
    is NaN keep the previous level and are masked out of the likelihood.
    Steps past the end of ``y`` keep the level fixed and are left unmasked.

    NaN entries are replaced by a finite placeholder before they are used, so
    no NaN enters the level update or the conditioned values. At masked steps
    the value collected for ``pred`` is that placeholder, not an imputed draw.

    Args:
        y: Observed series; its first dimension is the number of time steps.
            NaN marks a missing observation.
        future: Number of steps to forecast beyond the end of ``y``. When
            positive, the last ``future`` predictions are recorded as the
            deterministic site ``y_forecast``.
    """
    t_max = y.shape[0]

    # jnp.where evaluates both of its branches, so a NaN in the branch that is
    # not selected still poisons the gradient. Swap NaN for a finite
    # placeholder up front and carry the missingness in a separate mask.
    is_observed = ~jnp.isnan(y)
    y_filled = jnp.where(is_observed, y, 0.0)

    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    def transition_fn(carry, t):
        previous_level = carry

        in_sample = t < t_max
        # Clamp so forecast steps (t >= t_max) read a valid element of y.
        idx = jnp.minimum(t, t_max - 1)
        has_obs = in_sample & is_observed[idx]

        # Update the level only when a real observation is available;
        # otherwise (missing value or forecast step) carry it forward.
        level = jnp.where(
            has_obs,
            level_smoothing * y_filled[idx] + (1 - level_smoothing) * previous_level,
            previous_level,
        )

        # The prediction for step t uses the level *before* seeing y[t].
        mu = previous_level
        # Per-step mask: missing in-sample steps add nothing to the
        # log-density, forecast steps stay unmasked.
        with numpyro.handlers.mask(mask=has_obs | ~in_sample):
            pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))

        return level, pred

    # Condition on the NaN-free series; the placeholder values sit at masked
    # steps and therefore do not influence the posterior.
    with numpyro.handlers.condition(data={"pred": y_filled}):
        _, preds = scan(
            transition_fn,
            level_init,
            jnp.arange(t_max + future),
        )

    if future > 0:
        # Only the steps past the end of y are exposed as forecasts.
        numpyro.deterministic("y_forecast", preds[-future:])


def run_inference(model, rng_key, y, future=0):
    """Fit ``model`` with NUTS and return the MCMC object.

    Runs 2 chains of 500 warmup and 500 sampling iterations with a progress
    bar, prints the posterior summary, and prints the wall-clock time spent.

    Args:
        model: Model callable accepting ``y`` and ``future`` keyword arguments.
        rng_key: PRNG key passed to ``MCMC.run``.
        y: Observed series forwarded to the model.
        future: Number of forecast steps forwarded to the model.

    Returns:
        The ``MCMC`` instance after ``run`` has completed.
    """
    t0 = time.time()

    mcmc = numpyro.infer.MCMC(
        numpyro.infer.NUTS(model),
        num_warmup=500,
        num_samples=500,
        num_chains=2,
        progress_bar=True,
    )
    mcmc.run(rng_key, y=y, future=future)
    mcmc.print_summary()

    elapsed = time.time() - t0
    print("\nMCMC elapsed time:", elapsed)
    return mcmc


def generate_forecasts(mcmc, rng_key, y, future_steps):
    """Draw posterior-predictive forecasts from ``level_model``.

    Args:
        mcmc: Object whose ``get_samples()`` supplies the posterior samples.
        rng_key: PRNG key passed to the predictive call.
        y: Observed series forwarded to the model.
        future_steps: Number of steps to forecast; passed as ``future``.

    Returns:
        The result of the ``Predictive`` call, restricted to the
        ``"y_forecast"`` site.
    """
    print(f"Generating {future_steps} step ahead forecasts...")

    posterior = mcmc.get_samples()
    forecaster = Predictive(
        level_model,
        posterior_samples=posterior,
        return_sites=["y_forecast"],
    )
    return forecaster(rng_key, y=y, future=future_steps)


def main():
    """Simulate a noisy sine series with one missing value, fit, and forecast.

    Builds 100 points of ``sin(0.1 * t)`` plus Gaussian noise, sets the entry
    at index 2 to NaN, runs inference on the series, generates 20-step-ahead
    forecasts, and prints summary statistics of the forecast samples.
    """
    num_data = 100
    future_steps = 20

    # Split before drawing the noise so that the key consumed by the data
    # generation is never reused for inference or forecasting.
    rng_key, data_key = jax.random.split(jax.random.PRNGKey(0))
    t = jnp.arange(0, num_data)
    y = jnp.sin(t * 0.1) + random.normal(data_key, (num_data,)) * 0.2

    # Mark the observation at index 2 as missing.
    y = y.at[2].set(np.nan)

    print(f"Generated {num_data} training points, forecasting {future_steps} steps ahead")
    print(y)

    # Show which steps carry an observation; forecast steps are marked False.
    obs_mask = ~jnp.isnan(y)
    forecast_mask = jnp.zeros(future_steps, dtype=bool)
    extended_mask = jnp.concatenate([obs_mask, forecast_mask])
    print(extended_mask)

    # run inference
    rng_key, rng_subkey = jax.random.split(rng_key)
    mcmc = run_inference(level_model, rng_subkey, y, future=0)

    # generate forecasts
    rng_key, rng_subkey = jax.random.split(rng_key)
    forecasts = generate_forecasts(mcmc, rng_subkey, y, future_steps)

    # Summarise across the leading (sample) axis.
    forecast_samples = forecasts["y_forecast"]
    forecast_mean = np.mean(forecast_samples, axis=0)
    forecast_std = np.std(forecast_samples, axis=0)

    print("\nForecast Summary:")
    print(f"Mean forecast values: {forecast_mean[:5]}... (showing first 5)")
    print(f"Forecast std dev: {forecast_std[:5]}... (showing first 5)")
    print(f"Average forecast uncertainty: {np.mean(forecast_std):.3f}")


if __name__ == "__main__":
    # Select the CPU backend and request two host devices before main() runs.
    # NOTE(review): presumably the two devices are for the two MCMC chains
    # requested in run_inference (num_chains=2) -- confirm.
    numpyro.set_platform("cpu")
    numpyro.set_host_device_count(2)

    main()