I’m trying to implement ARMA(1, 1) in NumPyro. Here’s what I have so far:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def arma(y):
    # AR coefficient, restricted to (-1, 1) for stationarity
    alpha = numpyro.sample(
        'alpha',
        dist.TruncatedDistribution(dist.Normal(0, 1), low=-1, high=1),
    )
    # MA coefficient, restricted to (-1, 1) for invertibility
    phi = numpyro.sample(
        'phi',
        dist.TruncatedDistribution(dist.Normal(0, 1), low=-1, high=1),
    )
    mu = numpyro.sample('mu', dist.Normal(0, 1))
    sigma = numpyro.sample('sigma', dist.HalfNormal(1))
    # First residual: treat the unobserved y[-1] as mu and err[-1] as 0
    err = [y[0] - (mu + alpha * mu)]
    for t in range(1, len(y)):
        # One-step prediction: mu + alpha * y[t-1] + phi * err[t-1]
        err.append(y[t] - (mu + alpha * y[t - 1] + phi * err[-1]))
    # Gaussian log-likelihood of the residuals, summed over time
    numpyro.factor(
        'factor',
        dist.Normal(0, sigma).log_prob(jnp.stack(err)).sum(),
    )
```

(where `y` is my data array).

The docs for the HMM example suggest using `lax.scan` instead of a Python for-loop, but the example there seems very far removed from my model.
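For reference, here is a minimal sketch of the `lax.scan` contract as I understand it, on a toy running sum unrelated to the model: the step function receives the carried state and one slice of the input, and returns the new state plus an output that `scan` stacks into an array.

```python
import jax.numpy as jnp
from jax import lax


# lax.scan(f, init, xs) calls f(carry, x) once per slice of xs along axis 0.
# f must return (new_carry, output); the outputs are stacked into one array.
def step(carry, x):
    new_carry = carry + x        # state threaded from one iteration to the next
    return new_carry, new_carry  # also emit the running total at this step


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final, running = lax.scan(step, jnp.array(0.0), xs)
# final == 10.0, running == [1.0, 3.0, 6.0, 10.0]
```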

Any suggestions on how to use `lax.scan` in the above model?