Using lax.scan to implement ARMA(1, 1)

I’m trying to implement ARMA(1, 1) in NumPyro. Here’s what I have so far:

```python
import jax.numpy as jnp
import numpyro

def arma(y):
    alpha = numpyro.sample(
        'alpha',
        numpyro.distributions.TruncatedDistribution(numpyro.distributions.Normal(0, 1), low=-1, high=1),
    )
    phi = numpyro.sample(
        'phi',
        numpyro.distributions.TruncatedDistribution(numpyro.distributions.Normal(0, 1), low=-1, high=1),
    )
    mu = numpyro.sample(
        'mu',
        numpyro.distributions.Normal(0, 1),
    )

    sigma = numpyro.sample(
        'sigma',
        numpyro.distributions.HalfNormal(1),
    )

    # Residuals computed one time step at a time with a Python loop
    err = [y[0] - (mu + phi*mu)]
    for t in range(1, len(y)):
        err.append(y[t] - (alpha*y[t-1] + phi*err[-1]))

    # Add the Normal(0, sigma) log-likelihood of every residual to the model
    numpyro.factor(
        'factor',
        numpyro.distributions.Normal(0, sigma).log_prob(jnp.array(err)),
    )
```

(where `y` is my observed series, a 1-D array).
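Written out, the loop above computes the residuals with the recursion below (using the code's names: α multiplies the previous observation, φ the previous residual):

$$
\begin{aligned}
e_0 &= y_0 - (\mu + \phi\,\mu), \\
e_t &= y_t - (\alpha\, y_{t-1} + \phi\, e_{t-1}), \qquad t = 1, \dots, T-1, \\
e_t &\sim \mathcal{N}(0, \sigma^2).
\end{aligned}
$$

Each $e_t$ depends on $e_{t-1}$, so the residuals cannot be produced by a single vectorized expression; that sequential dependence is exactly what a loop, or `lax.scan`, is needed for. As written, $\mu$ enters only through $e_0$.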

The HMM docs suggest using lax.scan instead of a Python for-loop, but the example there seems very far removed from my model.

Any suggestions on how to use lax.scan in the above model?
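For reference, `jax.lax.scan(f, init, xs)` threads a "carry" through `f` one element of `xs` at a time; `f` must return `(new_carry, output)`, and the outputs are stacked into an array. A toy running sum (not ARMA-specific, purely illustrative) shows the contract:

```python
import jax
import jax.numpy as jnp

def running_sum(carry, x):
    # carry: the total so far; x: the current element of xs
    new_carry = carry + x
    return new_carry, new_carry  # pass the total forward and also record it

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final, partial_sums = jax.lax.scan(running_sum, jnp.zeros(()), xs)
# final is 10.0, partial_sums is [1., 3., 6., 10.]
```

In the ARMA model, the carry would be the pair (previous residual, previous observation) and `xs` would be `y[1:]`.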

@marcogorelli Could you sketch the relationships between the variables in your model? My impression so far is that scan can be used for this type of model.

Here is my sketch:

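```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Goes inside the model, after alpha, phi, mu and sigma have been sampled.
def transition_fn(carry, y_t):
    # carry = (previous residual, previous observation)
    last_err, last_y = carry
    err = y_t - (alpha * last_y + phi * last_err)
    # return the new carry, plus the value to stack into the output
    return (err, y_t), err

err_0 = y[0] - (mu + phi*mu)
_, err = jax.lax.scan(transition_fn, (err_0, y[0]), y[1:])
err = jnp.concatenate([err_0[None], err])  # prepend the initial residual
numpyro.sample("err", dist.Normal(0, sigma), obs=err)
```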
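One way to convince yourself the scan version matches the original for-loop is to compute the residuals both ways for fixed parameter values. A minimal sketch in plain JAX; the function names `residuals_loop`/`residuals_scan` and the data and parameter values are hypothetical, chosen only for illustration:

```python
import jax
import jax.numpy as jnp

def residuals_loop(y, alpha, phi, mu):
    # Python for-loop, as in the original question
    err = [y[0] - (mu + phi * mu)]
    for t in range(1, len(y)):
        err.append(y[t] - (alpha * y[t - 1] + phi * err[-1]))
    return jnp.stack(err)

def residuals_scan(y, alpha, phi, mu):
    # Same recursion expressed with lax.scan
    def transition_fn(carry, y_t):
        last_err, last_y = carry
        err = y_t - (alpha * last_y + phi * last_err)
        return (err, y_t), err

    err_0 = y[0] - (mu + phi * mu)
    _, err = jax.lax.scan(transition_fn, (err_0, y[0]), y[1:])
    return jnp.concatenate([err_0[None], err])

# Hypothetical data and parameters, for illustration only
y = jnp.array([0.5, -0.2, 1.3, 0.7, -0.4, 0.1])
alpha, phi, mu = 0.6, -0.3, 0.2

loop_err = residuals_loop(y, alpha, phi, mu)
scan_err = residuals_scan(y, alpha, phi, mu)
max_diff = float(jnp.max(jnp.abs(loop_err - scan_err)))
```

The practical difference shows up under tracing (e.g. inside NUTS, which JIT-compiles the model): the Python loop unrolls into one set of operations per time step, so compile time grows with `len(y)`, whereas `scan` traces `transition_fn` once.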

The forecasting tutorial in the docs is also a good example of using scan.


Indeed, that works - thanks!

IMO, this is a lot simpler than the SGT example in the docs. Would an ARMA example notebook be a welcome addition to the docs?

Sure, please go ahead!