Hello everyone again,
I’m working on some time series models based on @juanitorduz’s posts. The long-term goal is to extend his hierarchical multivariate ETS model, but to start I’m dealing with the univariate case.
I’m using the usual pattern of scan + handlers.condition. This works fine when I have a complete time series, but now I have missing data. I want to impute these points and not have them contribute to the log-likelihood, something I’ve done with mask in other projects.
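For context, here’s a minimal (non-scan) sketch of the mask pattern I mean; the model and names are just for illustration, and I’m filling the NaNs with a dummy value so they never enter the density:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def toy_model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    obs_mask = ~jnp.isnan(y)
    # dummy value at the missing positions; the mask removes their contribution
    y_filled = jnp.where(obs_mask, y, 0.0)
    with numpyro.handlers.mask(mask=obs_mask):
        numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y_filled)
```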
Here’s the model, adapted from Juan’s post but with an added jnp.where to carry the level forward (i.e. impute it) where the data are missing.
```python
def level_model(y: Array, future: int = 0) -> None:
    t_max = y.shape[0]

    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    def transition_fn(carry, t):
        previous_level = carry

        level = jnp.where(
            t < t_max,
            jnp.where(
                jnp.isnan(y[t]),
                previous_level,  # keep previous level if observation is NaN
                level_smoothing * y[t] + (1 - level_smoothing) * previous_level,  # usual update
            ),
            previous_level,  # during forecasting period
        )
        mu = previous_level

        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))
        return level, pred

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            level_init,
            jnp.arange(t_max + future),
        )

    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
```
The model works when the data are complete, but when any data are missing, sampling gives divergences at every iteration (and SVI cannot find initial params).
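In case it helps with diagnosis: a quick way to check whether the joint log-density is already non-finite before any sampling happens would be something like the snippet below (parameter values picked arbitrarily; this is just a sketch of the check, not something I’ve dug into yet):

```python
from numpyro.infer.util import log_density

# Evaluate the joint log-density at fixed parameter values; a nan here would
# explain divergences at every iteration.
fixed_params = {
    "level_smoothing": jnp.array(0.5),
    "level_init": jnp.array(0.0),
    "noise": jnp.array(1.0),
}
logp, _ = log_density(level_model, (y,), {"future": 0}, fixed_params)
print(logp)
```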
Things I’ve tried:
- Wrapping the scan in numpyro.handlers.mask. This still gives divergences.
```python
obs_mask = ~jnp.isnan(y)
# extend the mask to cover the forecast horizon (length t_max + future)
forecast_mask = jnp.zeros(future, dtype=bool)
extended_mask = jnp.concatenate([obs_mask, forecast_mask])

with numpyro.handlers.condition(data={"pred": y}):
    with numpyro.handlers.mask(mask=extended_mask):
        _, preds = scan(
            transition_fn,
            level_init,
            jnp.arange(t_max + future),
        )
```
- Passing obs_mask to the sample statement directly. Still divergences. (One more variant I haven’t actually tried yet is sketched after this list.)
```python
pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise), obs_mask=~jnp.isnan(y[t]))
```
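The untried variant (so this is only a sketch, not something I know fixes the root cause): compute the mask up front, replace the NaNs with a finite placeholder so nothing inside the scan ever touches a NaN, and drive both the level update and the masking off the boolean mask instead of jnp.isnan. The names (level_model_filled, y_filled, observed) are just mine; imports are the same as in the script at the end of the post.

```python
def level_model_filled(y: Array, future: int = 0) -> None:
    t_max = y.shape[0]
    obs_mask = ~jnp.isnan(y)
    y_filled = jnp.where(obs_mask, y, 0.0)  # arbitrary placeholder, masked out below

    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    def transition_fn(carry, t):
        previous_level = carry
        # only update the level inside the observed window, and only where the
        # observation is not missing
        observed = (t < t_max) & obs_mask[jnp.clip(t, 0, t_max - 1)]
        level = jnp.where(
            observed,
            level_smoothing * y_filled[jnp.clip(t, 0, t_max - 1)]
            + (1 - level_smoothing) * previous_level,
            previous_level,  # missing observation or forecasting period
        )
        mu = previous_level
        # mask removes missing / forecast steps from the log-likelihood
        with numpyro.handlers.mask(mask=observed):
            pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise))
        return level, pred

    with numpyro.handlers.condition(data={"pred": y_filled}):
        _, preds = scan(transition_fn, level_init, jnp.arange(t_max + future))

    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])
```

I’d still like to understand why the plain mask approaches above diverge, though.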
I’m interested to know if anyone has found a solution to this, or if there’s another way of going about this problem.
Cheers,
Theo
Somewhat related: Multioutput Kalman filter with scan and missing observations - #2 by fehiepsi and Calculating log_likelihood for model with scan - #3 by julianstastny
Full reproducible script
```python
import time

import jax
from jax import random, Array
import jax.numpy as jnp
import numpyro
from numpyro.contrib.control_flow import scan
from numpyro.infer import Predictive
import numpyro.distributions as dist
import numpy as np


def level_model(y: Array, future: int = 0) -> None:
    t_max = y.shape[0]

    level_smoothing = numpyro.sample(
        "level_smoothing", dist.Beta(concentration1=1, concentration0=1)
    )
    level_init = numpyro.sample("level_init", dist.Normal(loc=0, scale=1))
    noise = numpyro.sample("noise", dist.HalfNormal(scale=1))

    def transition_fn(carry, t):
        previous_level = carry

        level = jnp.where(
            t < t_max,
            jnp.where(
                jnp.isnan(y[t]),
                previous_level,  # keep previous level if observation is NaN
                level_smoothing * y[t] + (1 - level_smoothing) * previous_level,  # usual update
            ),
            previous_level,  # during forecasting period
        )
        mu = previous_level

        pred = numpyro.sample("pred", dist.Normal(loc=mu, scale=noise), obs_mask=~jnp.isnan(y[t]))
        return level, pred

    obs_mask = ~jnp.isnan(y)
    # extend the mask to cover the forecast horizon (length t_max + future)
    forecast_mask = jnp.zeros(future, dtype=bool)
    extended_mask = jnp.concatenate([obs_mask, forecast_mask])

    with numpyro.handlers.condition(data={"pred": y}):
        # with numpyro.handlers.mask(mask=extended_mask):
        _, preds = scan(
            transition_fn,
            level_init,
            jnp.arange(t_max + future),
        )

    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])


def run_inference(model, rng_key, y, future=0):
    start = time.time()
    sampler = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(
        sampler,
        num_warmup=500,
        num_samples=500,
        num_chains=2,
        progress_bar=True,
    )
    mcmc.run(rng_key, y=y, future=future)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
    return mcmc


def generate_forecasts(mcmc, rng_key, y, future_steps):
    print(f"Generating {future_steps} step ahead forecasts...")
    predictive = Predictive(
        level_model,
        posterior_samples=mcmc.get_samples(),
        return_sites=["y_forecast"],
    )
    forecasts = predictive(rng_key, y=y, future=future_steps)
    return forecasts


def main():
    num_data = 100
    future_steps = 20
    rng_key = jax.random.PRNGKey(0)

    t = jnp.arange(0, num_data)
    y = jnp.sin(t * 0.1) + random.normal(rng_key, (num_data,)) * 0.2
    # make one element NaN (index 2)
    y = y.at[2].set(np.nan)
    print(f"Generated {num_data} training points, forecasting {future_steps} steps ahead")
    print(y)

    obs_mask = ~jnp.isnan(y)
    # extend the mask to cover the forecast horizon (length t_max + future)
    forecast_mask = jnp.zeros(future_steps, dtype=bool)
    extended_mask = jnp.concatenate([obs_mask, forecast_mask])
    print(extended_mask)

    # run inference
    rng_key, rng_subkey = jax.random.split(rng_key)
    mcmc = run_inference(level_model, rng_subkey, y, future=0)

    # generate forecasts
    rng_key, rng_subkey = jax.random.split(rng_key)
    forecasts = generate_forecasts(mcmc, rng_subkey, y, future_steps)

    forecast_samples = forecasts["y_forecast"]
    forecast_mean = np.mean(forecast_samples, axis=0)
    forecast_std = np.std(forecast_samples, axis=0)
    print("\nForecast Summary:")
    print(f"Mean forecast values: {forecast_mean[:5]}... (showing first 5)")
    print(f"Forecast std dev: {forecast_std[:5]}... (showing first 5)")
    print(f"Average forecast uncertainty: {np.mean(forecast_std):.3f}")


if __name__ == "__main__":
    numpyro.set_platform("cpu")
    numpyro.set_host_device_count(2)
    main()
```