I’m trying to implement ARMA(1, 1) in NumPyro. Here’s what I have so far:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def arma(y):
    # AR coefficient (multiplies y[t-1]), truncated to (-1, 1)
    alpha = numpyro.sample(
        'alpha',
        dist.TruncatedDistribution(dist.Normal(0, 1), low=-1, high=1),
    )
    # MA coefficient (multiplies the previous residual), truncated to (-1, 1)
    phi = numpyro.sample(
        'phi',
        dist.TruncatedDistribution(dist.Normal(0, 1), low=-1, high=1),
    )
    mu = numpyro.sample(
        'mu',
        dist.Normal(0, 1),
    )
    sigma = numpyro.sample(
        'sigma',
        dist.HalfNormal(1),
    )
    # Residuals are built one at a time; each depends on the previous one
    err = [y[0] - (mu + phi*mu)]
    for t in range(1, len(y)):
        err.append(y[t] - (alpha*y[t-1] + phi*err[-1]))
    # Add the Normal log-likelihood of all residuals to the model's log density
    numpyro.factor(
        'factor',
        dist.Normal(0, sigma).log_prob(jnp.array(err)),
    )
```
(where `y` is my array).
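For reference, here is the recursion the loop computes, transcribed term by term from the code as written (so $\mu$ enters only through $\varepsilon_0$), with $T$ = `len(y)` and $\ell$ denoting the term passed to `numpyro.factor`. Each residual depends only on the previous residual and the previous observation, which is exactly the kind of carried state `lax.scan` is designed for.

$$
\begin{aligned}
\varepsilon_0 &= y_0 - (\mu + \phi\,\mu),\\
\varepsilon_t &= y_t - (\alpha\, y_{t-1} + \phi\, \varepsilon_{t-1}), \qquad t = 1, \dots, T-1,\\
\ell &= \sum_{t=0}^{T-1} \log \mathcal{N}(\varepsilon_t \mid 0, \sigma).
\end{aligned}
$$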
The HMM example in the docs suggests using `lax.scan` instead of a Python for-loop, but that example seems very far removed from my model.
Any suggestions on how to use `lax.scan` in the model above?