Pyro performs dramatically slower than PyMC3 with Normalizing Flows on stochastic-volatility model inference

If you use NumPyro, then there are two ways to make it fast:

  • Use the AR distribution defined below (please double-check the implementation; I just sketched it quickly following the GaussianRandomWalk pattern - GRW can be seen as a special case of AR where coef=1):
from jax import lax, random
import jax.numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import Distribution, Normal, constraints
from numpyro.distributions.util import validate_sample


class AR(Distribution):
    arg_constraints = {'init': constraints.real_vector,
                       'coef': constraints.real_vector,
                       'scale': constraints.positive,
                       'num_steps': constraints.positive_integer}
    support = constraints.real_vector
    reparametrized_params = ['scale']

    def __init__(self, init, coef, scale=1., num_steps=1, validate_args=None):
        assert np.shape(num_steps) == ()
        assert init.shape[-1] == coef.shape[-1]
        self.init = init
        self.coef = coef
        self.scale = scale
        self.num_steps = num_steps
        batch_shape = lax.broadcast_shapes(np.shape(init)[:-1],
                                           np.shape(coef)[:-1],
                                           np.shape(scale))
        event_shape = (num_steps,)
        super(AR, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # just a fake sample method to get initial params for autoguide and mcmc
        eps = random.normal(key, sample_shape + self.batch_shape + self.event_shape)
        return np.expand_dims(self.scale, -1) * eps

    @validate_sample
    def log_prob(self, value):
        assert value.shape[-1] == self.num_steps
        batch_shape = lax.broadcast_shapes(self.batch_shape, np.shape(value)[:-1])
        init = np.broadcast_to(self.init, batch_shape + self.init.shape[-1:])
        value = np.broadcast_to(value, batch_shape + (self.num_steps,))
        x = np.concatenate([init, value], -1)
        # build the lagged regressors; slice with ... so only the event axis is indexed
        x_reg = np.stack([x[..., i:self.num_steps + i] for i in range(self.coef.shape[-1])], -1)
        noise = value - (self.coef * x_reg).sum(-1)
        return Normal(0, np.expand_dims(self.scale, -1)).log_prob(noise).sum(-1)

Here coef plays the role of rho in PyMC3 and init holds the initial values of the AR process. You can put a Normal prior on init in a separate numpyro.sample statement, as the model below does for h_init (PyMC3 uses a flat prior on the initial values by default, so its AR is an improper distribution by default).

  • Use lax.scan with a reparameterized (non-centered) AR (as in this topic); a sketch is given right after this list.
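Here is a minimal sketch of that second option, under my own assumptions: a plain AR(1) with intercept mu * (1 - phi), written in non-centered form so the only latent random variables are iid standard-normal innovations, while the AR path is built deterministically by lax.scan. The names (ar1_scan_model, num_steps, obs) are illustrative, not taken from the original code.

from jax import lax
import jax.numpy as np

import numpyro
import numpyro.distributions as dist


def ar1_scan_model(num_steps, obs=None):
    # AR(1) with intercept: h[t] = mu * (1 - phi) + phi * h[t-1] + sigma * eps[t]
    phi = numpyro.sample("phi", dist.Uniform(-1, 1))
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.))
    h0 = numpyro.sample("h0", dist.Normal(0, 1))
    # iid innovations; the AR structure itself is deterministic, which is the
    # reparameterization that makes HMC/SVI behave well here
    eps = numpyro.sample("eps", dist.Normal(0., 1.), sample_shape=(num_steps,))

    def step(h_prev, eps_t):
        h_t = mu * (1 - phi) + phi * h_prev + sigma * eps_t
        return h_t, h_t

    _, h = lax.scan(step, h0, eps)
    return numpyro.sample("y", dist.Normal(0., np.exp(h / 2.)), obs=obs)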

Note that nf = 'scale*10-loc*10' is equivalent to numpyro.contrib.autoguide.AutoDiagonalNormal, so there is no need to use AutoBNAFNormal.

Edit: I just tested this - MCMC is pretty fast (>100 it/s) and SVI finishes in seconds; a rough sketch of the inference calls is given after the model below.

def stoch_vol_model():
    phi = numpyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = numpyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    h_init = numpyro.sample("h_init", dist.Normal(0, 1), sample_shape=(2,))
    N = len(returns)
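    # latent AR over the remaining N - 2 steps, conditioned on the two sampled initial values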
    h_ar = numpyro.sample("h_ar", AR(h_init, np.stack([phi, mu * (1 - phi)]), sigma2, N - 2))
    h = np.concatenate([h_init, h_ar], -1)
    return numpyro.sample('y', dist.Normal(0., np.exp(h / 2.)), obs=returns)
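
For reference, inference can be run along these lines (a rough sketch, not the exact commands; import paths differ across NumPyro versions - older releases keep AutoDiagonalNormal in numpyro.contrib.autoguide, name the loss ELBO instead of Trace_ELBO, and lack svi.run):

from jax import random

import numpyro
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal

# MCMC with NUTS
mcmc = MCMC(NUTS(stoch_vol_model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()

# SVI with a diagonal-normal autoguide (cf. the note on AutoBNAFNormal above)
guide = AutoDiagonalNormal(stoch_vol_model)
svi = SVI(stoch_vol_model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(1), 2000)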

Btw, I think it is better to follow the stable example. Using a latent AR distribution like this is not scalable, and with such a high number of latent variables I think the inference would be unstable. Using AR as an observation site would be fine, though.