Predict with exponential smoothing

Looking at Forecasting: Principles and Practice, I’ve tried writing the simple exponential smoothing model in NumPyro:

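import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def ses(arr):
    T = len(arr)
    alpha = numpyro.sample('alpha', dist.Beta(1, 1))
    sigma = numpyro.sample('sigma', dist.HalfNormal(1))
    level_0 = numpyro.sample('level_0', dist.Normal(0, 1))

    def transition(carry, y):
        # carry holds the previous observation and the previous level
        y_1, level_1 = carry
        level = alpha*y_1 + (1-alpha)*level_1
        return (y, level), level

    # levels[i] is the one-step-ahead forecast of arr[i + 1]; the last
    # observation arr[-1] is never folded into a level
    _, levels = jax.lax.scan(transition, (arr[0], level_0), arr[1:])

    levels = numpyro.deterministic('levels', jnp.asarray(levels))

    numpyro.sample('obs', dist.Normal(levels, sigma), obs=arr[1:])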

Then, from the forecast equation in the linked page, I can make predictions by taking the last value of levels.
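For reference, simple exponential smoothing in the component form used by the book, together with the equivalent error-correction (state space) form, where ℓ_t is the level and ε_t ~ N(0, σ²) is the one-step error:

```latex
\begin{aligned}
\hat{y}_{T+h \mid T} &= \ell_T, \qquad h = 1, 2, \dots \\
\ell_t &= \alpha\, y_t + (1 - \alpha)\, \ell_{t-1} \\[4pt]
y_t &= \ell_{t-1} + \varepsilon_t \\
\ell_t &= \ell_{t-1} + \alpha\, \varepsilon_t
\end{aligned}
```

The first line is why every horizon gets the same point forecast ℓ_T. Substituting y_t = ℓ_{t-1} + ε_t into the smoothing equation gives the last line, which shows that feeding simulated errors ε_{T+1}, ε_{T+2}, … back into the level is what produces future sample paths.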

But then I get the same prediction, with the same uncertainty, for every future time point, whereas I would expect the uncertainty to grow the further ahead we forecast.

On this page, they give formulae for computing the prediction interval analytically, as a function of the forecast horizon.
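For simple exponential smoothing, ETS(A,N,N), the analytic h-step forecast variance is σ_h² = σ²[1 + α²(h − 1)], so the interval widens with the horizon while its centre stays at ℓ_T. A minimal sketch of the resulting 95% interval, assuming α, σ and the final level are known point values (the helper name `ses_interval` is hypothetical, chosen for illustration):

```python
import math


def ses_interval(level_T, alpha, sigma, h, z=1.96):
    """Analytic prediction interval for simple exponential smoothing, ETS(A,N,N).

    The point forecast is the last level for every horizon; only the
    width depends on h, through sigma_h^2 = sigma^2 * (1 + alpha^2 * (h - 1)).
    """
    sigma_h = sigma * math.sqrt(1 + alpha**2 * (h - 1))
    return level_T - z * sigma_h, level_T + z * sigma_h


# Intervals for the first five steps ahead: centred on 10.0, widening with h.
for h in range(1, 6):
    lo, hi = ses_interval(level_T=10.0, alpha=0.5, sigma=1.0, h=h)
    print(h, round(lo, 3), round(hi, 3))
```

This ignores uncertainty in α and σ themselves; a Bayesian posterior predictive would additionally average over their posterior draws.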

However, they also write

For a few ETS models, there are no known formulas for prediction intervals. In these cases, the forecast() function uses simulated future sample paths and computes prediction intervals from the percentiles of these simulated future paths.

Does anyone know how they do that? Is it possible to simulate future sample paths using the above-defined model in NumPyro, in such a way that we’ll see the uncertainty grow as we forecast farther and farther into the future?
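The simulation described in the quoted passage can be sketched with the state space form of simple exponential smoothing: draw an error for each future step, add it to the current level to get the simulated value, then move the level by α times that error, and finally read the interval off the percentiles across paths. This is a minimal NumPy sketch with fixed α and σ; `simulate_ses_paths` is a hypothetical helper, not the `forecast()` implementation.

```python
import numpy as np


def simulate_ses_paths(level_T, alpha, sigma, horizon, n_paths=10_000, seed=0):
    """Simulate future sample paths from the ETS(A,N,N) state space form.

    Each step draws eps ~ N(0, sigma^2), sets y = level + eps and then
    updates level <- level + alpha * eps, so earlier shocks carry forward.
    Returns an array of shape (n_paths, horizon).
    """
    rng = np.random.default_rng(seed)
    level = np.full(n_paths, float(level_T))
    paths = np.empty((n_paths, horizon))
    for h in range(horizon):
        eps = rng.normal(0.0, sigma, size=n_paths)
        paths[:, h] = level + eps
        level = level + alpha * eps
    return paths


paths = simulate_ses_paths(level_T=10.0, alpha=0.5, sigma=1.0, horizon=8)
# 95% interval at each horizon from the 2.5th and 97.5th percentiles of the paths
lower, upper = np.percentile(paths, [2.5, 97.5], axis=0)
print(np.round(upper - lower, 2))  # widths increase with the horizon
```

The empirical variance at step h should approach the analytic σ²[1 + α²(h − 1)], which is a useful sanity check on the simulation.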

Have you looked at the predator-prey model example?

I don’t know too much about this kind of model (or how you handle the fact that the value you call arr ends where the data ends), but it seems like they’re doing what you want!


Thanks @justinrporter! Yes, this seems to work:

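import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def ses(arr, future=0):
    T = arr.shape[0]
    alpha = numpyro.sample("alpha", dist.Beta(1, 1))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    level_init = arr[0]

    def transition_fn(carry, t):
        (level_1,) = carry
        # one-step-ahead forecast is the previous level
        mu = level_1
        # observed for t < T (via the condition handler below), sampled for t >= T
        pred = numpyro.sample("pred", dist.Normal(mu, sigma))
        # update the level with the observed or simulated value, so the noise of
        # each future step carries forward and the interval widens with the horizon
        level = alpha * pred + (1 - alpha) * level_1
        return (level,), pred

    # numpyro's scan uses arr[t] for "pred" while t < T and samples it afterwards
    with numpyro.handlers.condition(data={"pred": arr}):
        _, preds = scan(transition_fn, (level_init,), jnp.arange(T + future))

    if future > 0:
        # keep only the simulated future steps
        numpyro.deterministic("y_forecast", preds[-future:])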
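Whether the forecast interval widens depends on what the level does after the data ends. In the scan-based model, keeping `level_1` for t >= T corresponds to holding the level fixed, while using the sampled `pred` in the level update at every step corresponds to feeding the simulated value back in. A minimal NumPy sketch comparing the two, independent of NumPyro, with fixed α and σ and a hypothetical helper `future_spread`:

```python
import numpy as np


def future_spread(alpha, sigma, horizon, update_level, n_paths=20_000, seed=1):
    """Standard deviation of simulated future values at each horizon.

    update_level=False holds the level fixed after the data ends, so every
    future value is the same level plus independent noise; update_level=True
    feeds each simulated value back into the level, as SES does.
    """
    rng = np.random.default_rng(seed)
    level = np.zeros(n_paths)  # the starting level does not affect the spread
    preds = np.empty((n_paths, horizon))
    for h in range(horizon):
        pred = rng.normal(level, sigma)
        preds[:, h] = pred
        if update_level:
            level = alpha * pred + (1 - alpha) * level
    return preds.std(axis=0)


print(np.round(future_spread(0.5, 1.0, 6, update_level=False), 2))  # roughly flat at sigma
print(np.round(future_spread(0.5, 1.0, 6, update_level=True), 2))   # grows with the horizon
```

With the level held fixed the spread stays at σ for every step; with the update it follows σ√(1 + α²(h − 1)). In the Bayesian model, posterior uncertainty in α and σ adds some spread on top of either case.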