If you use NumPyro, then there are two ways to make it fast:
- Using the following predefined AR distribution (please double-check the implementation, I just quickly sketched it following the `GaussianRandomWalk` pattern - GRW can be seen as a special case of AR where `coef=1`):
```python
from jax import lax, random
import jax.numpy as np

import numpyro
from numpyro.distributions import Distribution, Normal, constraints
from numpyro.distributions.util import validate_sample


class AR(Distribution):
    arg_constraints = {'init': constraints.real_vector,
                       'coef': constraints.real_vector,
                       'scale': constraints.positive,
                       'num_steps': constraints.positive_integer}
    support = constraints.real_vector
    reparametrized_params = ['scale']

    def __init__(self, init, coef, scale=1., num_steps=1, validate_args=None):
        assert np.shape(num_steps) == ()
        # one initial value per AR coefficient
        assert init.shape[-1] == coef.shape[-1]
        self.init = init
        self.coef = coef
        self.scale = scale
        self.num_steps = num_steps
        batch_shape = lax.broadcast_shapes(np.shape(init)[:-1],
                                           np.shape(coef)[:-1],
                                           np.shape(scale))
        event_shape = (num_steps,)
        super(AR, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # just a fake sample method to get initial params for autoguide and mcmc
        eps = random.normal(key, sample_shape + self.batch_shape + self.event_shape)
        return np.expand_dims(self.scale, -1) * eps

    @validate_sample
    def log_prob(self, value):
        assert value.shape[-1] == self.num_steps
        batch_shape = lax.broadcast_shapes(self.batch_shape, np.shape(value)[:-1])
        init = np.broadcast_to(self.init, batch_shape + self.init.shape[-1:])
        value = np.broadcast_to(value, batch_shape + (self.num_steps,))
        x = np.concatenate([init, value], -1)
        # stack lagged values so that x_reg[..., t, i] = x[..., t + i]
        x_reg = np.stack([x[..., i:self.num_steps + i]
                          for i in range(self.coef.shape[-1])], -1)
        noise = value - (self.coef * x_reg).sum(-1)
        return Normal(0, np.expand_dims(self.scale, -1)).log_prob(noise).sum(-1)
```
Here `coef` is `rho` in PyMC3 and `init` holds the initial values of the AR process. You can put a Normal prior on `init` in a separate `numpyro.sample` statement (PyMC3 uses a flat prior for `init` by default, so by default the AR would be an improper distribution).
- Use `lax.scan` with a reparameterized AR (as in this topic); a minimal sketch follows this list.
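Here is a rough sketch of that second option, reusing the variable names from the model further below. It assumes an AR(1)-style transition `h_t = mu + phi * (h_{t-1} - mu) + sigma * eps_t` (that exact parameterization is my assumption, not something fixed by this thread): the only high-dimensional latent site is a vector of iid standard-normal innovations, and the AR structure lives in a deterministic `lax.scan`.

```python
import jax.numpy as np
from jax import lax
import numpyro
import numpyro.distributions as dist


def stoch_vol_scan_model(returns):
    phi = numpyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = numpyro.sample("sigma2", dist.InverseGamma(2.5, 0.025))
    sigma = np.sqrt(sigma2)
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    h0 = numpyro.sample("h0", dist.Normal(mu, sigma))
    # iid standard-normal innovations; the AR recursion is deterministic given these
    eps = numpyro.sample("eps", dist.Normal(0., 1.), sample_shape=(len(returns) - 1,))

    def transition(h_prev, eps_t):
        h_t = mu + phi * (h_prev - mu) + sigma * eps_t
        return h_t, h_t

    _, h_rest = lax.scan(transition, h0, eps)
    h = np.concatenate([np.expand_dims(h0, -1), h_rest], -1)
    return numpyro.sample("y", dist.Normal(0., np.exp(h / 2.)), obs=returns)
```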
Note that `nf = 'scale*10-loc*10'` is equivalent to `numpyro.contrib.autoguide.AutoDiagonalNormal`, so there is no need to use `AutoBNAFNormal`.
Edit: I just tested it; MCMC is pretty fast (>100 it/s) and SVI finishes in seconds.
```python
import numpyro.distributions as dist


def stoch_vol_model():
    # `returns` is the observed series (a global 1-D array here)
    phi = numpyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = numpyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    # Normal prior on the initial values of the AR process
    h_init = numpyro.sample("h_init", dist.Normal(0, 1), sample_shape=(2,))
    N = len(returns)
    h_ar = numpyro.sample("h_ar", AR(h_init, np.stack([phi, mu * (1 - phi)]), sigma2, N - 2))
    h = np.concatenate([h_init, h_ar], -1)
    return numpyro.sample('y', dist.Normal(0., np.exp(h / 2.)), obs=returns)
```
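For reference, this is roughly how one might run both inferences on the model above (a sketch: the import paths follow the current `numpyro.infer` API, older versions expose the guide under `numpyro.contrib.autoguide` as mentioned earlier, and `returns` is assumed to already be defined as a 1-D array):

```python
from jax import random
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal

rng_key = random.PRNGKey(0)

# NUTS / MCMC
mcmc = MCMC(NUTS(stoch_vol_model), num_warmup=500, num_samples=500)
mcmc.run(rng_key)
mcmc.print_summary()

# SVI with a diagonal-normal guide (the AutoDiagonalNormal mentioned above)
guide = AutoDiagonalNormal(stoch_vol_model)
svi = SVI(stoch_vol_model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(rng_key, 2000)
```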
Btw, I think it is better to follow the stable example. Using a latent AR distribution like this is not scalable, and with that many latent variables I think the inference would be unstable! But using AR as an observation site would be fine; a rough sketch of that is below.
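As an illustration of that last point (the names and priors here are made up for the example), the AR distribution above can be used directly as a likelihood, so only a handful of latent variables remain:

```python
def ar_obs_model(y_obs, p=2):
    # priors on the AR parameters
    init = numpyro.sample("init", dist.Normal(0., 1.), sample_shape=(p,))
    coef = numpyro.sample("coef", dist.Normal(0., 1.), sample_shape=(p,))
    scale = numpyro.sample("scale", dist.HalfNormal(1.))
    # the whole observed series enters through a single observed AR site
    return numpyro.sample("y", AR(init, coef, scale, len(y_obs)), obs=y_obs)
```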