# Using `lax.scan` to implement ARMA(1, 1)

I’m trying to implement ARMA(1, 1) in NumPyro. Here’s what I have so far:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def arma(y):
    # AR coefficient (multiplies y[t-1]), truncated to (-1, 1)
    alpha = numpyro.sample(
        'alpha',
        dist.TruncatedDistribution(dist.Normal(0, 1), low=-1, high=1),
    )
    # MA coefficient (multiplies the previous error), truncated to (-1, 1)
    phi = numpyro.sample(
        'phi',
        dist.TruncatedDistribution(dist.Normal(0, 1), low=-1, high=1),
    )
    mu = numpyro.sample('mu', dist.Normal(0, 1))
    sigma = numpyro.sample('sigma', dist.HalfNormal(1))

    # One-step-ahead errors, built with a Python for-loop
    err = [y[0] - (mu + phi * mu)]
    for t in range(1, len(y)):
        err.append(y[t] - (alpha * y[t - 1] + phi * err[-1]))

    # Sum the per-step log-densities into a single scalar factor
    numpyro.factor(
        'factor',
        dist.Normal(0, sigma).log_prob(jnp.array(err)).sum(),
    )
```

(where `y` is my array).
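Written out, the loop above computes the following one-step-ahead errors (indexing from 1 here, whereas the code indexes from 0; $\alpha$ is the AR coefficient and $\phi$ the MA coefficient, as in the code), and the `factor` adds their log-density under $\mathcal{N}(0, \sigma)$:

$$
\begin{aligned}
\varepsilon_1 &= y_1 - (\mu + \phi\,\mu), \\
\varepsilon_t &= y_t - (\alpha\, y_{t-1} + \phi\, \varepsilon_{t-1}), \qquad t = 2, \dots, T, \\
\log p &= \sum_{t=1}^{T} \log \mathcal{N}(\varepsilon_t \mid 0, \sigma).
\end{aligned}
$$

Each $\varepsilon_t$ depends on $\varepsilon_{t-1}$ and $y_{t-1}$, which is the "carry a state from one step to the next" pattern that `lax.scan` is built for.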

The HMM example in the docs suggests using `lax.scan` instead of a for-loop, but that example looks very different from my model.

Any suggestions on how to use `lax.scan` in the above model?
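For reference, `jax.lax.scan(f, init, xs)` behaves like the Python loop in `scan_reference` below: `f` maps `(carry, x)` to `(new_carry, y)`, the carry is threaded through the iterations, and the per-step outputs are stacked. The sketch checks this on a toy MA(1)-style residual recursion `e[t] = y[t] - phi * e[t-1]`; the names `scan_reference` and `ma_step` are illustrative only, not part of JAX or NumPyro.

```python
import jax
import jax.numpy as jnp


def scan_reference(f, init, xs):
    # Pure-Python equivalent of jax.lax.scan for a 1-D xs
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)


phi = 0.5


def ma_step(last_err, y_t):
    # Toy MA(1)-style residual: the previous residual is the carry
    err = y_t - phi * last_err
    return err, err  # (new carry, per-step output)


y = jnp.array([1.0, 2.0, 3.0, 4.0])
init = jnp.zeros(())  # e[-1] = 0
_, errs_loop = scan_reference(ma_step, init, y)
_, errs_scan = jax.lax.scan(ma_step, init, y)
# both give [1.0, 1.5, 2.25, 2.875]
```

Any for-loop whose body only needs the previous iteration's values can be rewritten this way: put those values in the carry and return what you want collected as the second output.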

@marcogorelli Could you sketch the relationship between the variables in your model? My impression so far is that `scan` can be used for this type of model.

Here is my sketch:

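```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Inside arma(y), after sampling alpha, phi, mu and sigma:
def transition_fn(carry, y_t):
    # carry = (previous error, previous observation)
    last_err, last_y = carry
    err = y_t - (alpha * last_y + phi * last_err)
    return (err, y_t), err  # (new carry, per-step output)

err_0 = y[0] - (mu + phi * mu)
_, err = jax.lax.scan(transition_fn, (err_0, y[0]), y[1:])
err = jnp.concatenate([err_0[None], err])
numpyro.sample("err", dist.Normal(0, sigma), obs=err)
```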

The time-series forecasting tutorial is also a good example of using `scan`.
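In the same spirit, `scan` can also roll the recursion forward to produce point forecasts. The sketch below is not taken from the forecasting tutorial: `forecast` is a hypothetical helper that rearranges the thread's recursion as `y[t] = alpha * y[t-1] + phi * err[t-1] + err[t]` and assumes future errors are replaced by their mean of zero. With no per-step inputs, it calls `lax.scan` with `xs=None` and an explicit `length`.

```python
import jax
import jax.numpy as jnp


def forecast(last_y, last_err, alpha, phi, h):
    # Hypothetical helper: h-step-ahead point forecasts from
    # y[t] = alpha * y[t-1] + phi * err[t-1] + err[t], with future err[t] = 0
    def step(carry, _):
        prev_y, prev_err = carry
        y_hat = alpha * prev_y + phi * prev_err
        # the error at a forecast step is set to its mean, 0
        return (y_hat, jnp.zeros_like(prev_err)), y_hat

    _, y_hat = jax.lax.scan(step, (last_y, last_err), None, length=h)
    return y_hat


y_hat = forecast(jnp.float32(1.0), jnp.float32(0.5), alpha=0.5, phi=0.4, h=3)
# y_hat is approximately [0.7, 0.35, 0.175]
```

In a fitted model, `last_y` would be the final observation and `last_err` the final element of `err`, evaluated at each posterior draw.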


Indeed, that works - thanks!

IMO this is a lot simpler than the SGT example in the docs. Would an ARMA example notebook be a welcome addition?