Pyro performs dramatically slower than PyMC3 with Normalizing Flows on stochastic-volatility model inference

Hi guys,

Following the discussion in an issue on a similar topic, I realized that Pyro takes much longer for stochastic-volatility model inference, in particular when compared to PyMC3.

I came across the following related discussion on translating a very similar model from PyMC3 to Pyro; however, the difference in computation times still persists, and it is huge:

  • with Pyro it takes ages: even with a return time series of only 100 observations, it took 3h:26m:10s in my case (2,000 samples + 1,000 warm-up, and 50,000 steps for SVI). For a time series of 1,000 observations it shows ~700 sec/iter;
  • with PyMC3 it finished successfully for a time series of 1,000 observations in 1m:50s for 10,000 samples when I used normalizing-flow variational inference.

@fritzo suggested vectorizing the model. Could anybody help with how this can be done easily?
Any other ideas on what causes such slow performance and how to overcome it?

Below is the sample code for the model:

import argparse
import logging
from functools import partial
import torch
import pyro

from pyro import optim, poutine
import pyro.distributions as dist
from pyro.distributions.transforms import iterated, planar
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormalizingFlow
from pyro.infer.reparam import NeuTraReparam
from numpyro.examples.datasets import SP500, load_dataset

# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:101]

logging.basicConfig(format='%(message)s', level=logging.INFO)

def model(returns):
    
    phi = pyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = pyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = pyro.sample("mu", dist.Normal(0, 10))
    
    h = torch.empty(len(returns))
    for t in pyro.poutine.markov(range(len(returns))):
        if t == 0:
            h[t] = pyro.sample(f'h_{t}', dist.Normal(mu, sigma2**0.5/torch.sqrt(1. - phi * phi)).to_event(0))
        else:
            h[t] = pyro.sample(f'h_{t}', dist.Normal(mu + phi * (h[t-1] - mu), sigma2**0.5).to_event(0))

    y = pyro.sample('y', dist.Normal(0., (h / 2.).exp()), obs=returns)

# define aux fn to fit the guide
def fit_guide(guide, args):
    pyro.clear_param_store()
    adam = optim.Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, adam, Trace_ELBO())
    for i in range(args.num_steps):
        loss = svi.step(args.data)
        if i % 500 == 0:
            logging.info("[{}]Elbo loss = {:.2f}".format(i, loss))

# define aux fn to run HMC
def run_hmc(args, model, print_summary=False):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, warmup_steps=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(args.data)
    if print_summary:
        mcmc.summary()
    return mcmc


def main(args):
    pyro.set_rng_seed(args.rng_seed)
    
    outDict = {}
    
    # If we want the Normalizing Flow
    # fit autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(model, partial(iterated, args.num_flows, planar))
    fit_guide(guide, args)

    # Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['phi_shared_latent']


    samples = neutra.transform_sample(zs)

    outDict['nf_neutra_mcmc'] = mcmc
    outDict['nf_neutra_samples'] = samples
    
    
    return outDict


if __name__ == '__main__':
    assert pyro.__version__.startswith('1.2.1')
    parser = argparse.ArgumentParser(description='Example illustrating NeuTra Reparametrizer')
    parser.add_argument('-n', '--num-steps', default=1000, type=int,
                        help='number of SVI steps')
    parser.add_argument('-lr', '--learning-rate', default=1e-4, type=float,
                        help='learning rate for the Adam optimizer')
    parser.add_argument('--rng-seed', default=123, type=int,
                        help='RNG seed')
    parser.add_argument('--num-warmup', default=1000, type=int,
                        help='number of warmup steps for NUTS')
    parser.add_argument('--num-samples', default=2000, type=int,
                        help='number of samples to be drawn from NUTS')
    parser.add_argument('--data', default=torch.Tensor(returns), type=float,
                        help='Time-series of returns')
    parser.add_argument('--num-flows', default=4, type=int,
                        help='number of flows in the BNAF autoguide')

    args = parser.parse_args(args=[])
    
    outDict = main(args)

Thanks,
Arturs

Hi Arturs,

I would recommend vectorizing this model using pyro.plate. We actually have a stochastic volatility example you could fork. Note that for the exponential decay part mu + phi * (h[t-1] - mu) you can use the fast pyro.ops.tensor_utils.convolve() with a geometric kernel (1-phi) ** torch.arange(n).

Good luck,
Fritz

Without your PyMC3 code I can’t say for sure, but I don’t think the two times you’re comparing are measuring the same things, either in terms of model implementation (the PyMC example you linked to uses a single GaussianRandomWalk distribution rather than a Python for loop) or inference (it seems from your description like you’re comparing VI in PyMC to NeutraHMC in Pyro, and both the flows and optimization procedures may be different and have different hyperparameters?). Still, I agree the difference is too high.

As a general comment on performance: models in Pyro that perform many operations on many small tensors and use Python control flow heavily, like the version of the stochastic volatility model you’ve written here, have serious performance issues. As tensor sizes go up and the fraction of time spent during inference actually performing numerical computations like matrix multiplication goes up, these issues fade away.

Unfortunately in our experience these issues are largely reflections of overhead and performance issues inherent in the design of PyTorch Tensors, autograd, and jit, which are not currently optimized for large graphs like the ones in your time series model (see e.g. Tensor overhead, slow JIT compilation) or for the operations used heavily in certain inference algorithms (e.g. advanced indexing, einsum).

It’s usually possible to work around these issues by vectorizing your model as suggested by @fritzo, and in some cases our parallel-scan-based distribution implementations like GaussianHMM can even provide significant performance boosts for long time series. We also encourage affected users who can’t use these workarounds to try out JAX and NumPyro, which currently does a much better job in this regime thanks to XLA’s optimizations and which will approach inference feature parity with Pyro over time.

Some additional references if you want to try NumPyro: the stochastic volatility example, which runs a stochastic volatility model on a time series of length > 2400, and the normalizing flow example. They would be very fast. But the key point is vectorization, as @fritzo mentioned.

Thanks a lot, guys, for your prompt replies and different ideas. I saw the mentioned examples, but they use slightly different models (and the NumPyro example mentioned by @fehiepsi employs the predefined GaussianRandomWalk).
I will investigate vectorization using plate and keep you posted here on the progress.
@fritzo, I guess you mean phi ** torch.arange(n) ?
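
A quick numerical check of that reading, as a minimal NumPy sketch rather than Pyro code. The helper names ar1_loop and ar1_conv are made up for illustration, and the recursion is assumed to start at h[0] = mu + eps[0]. Unrolling h[t] = mu + phi * (h[t-1] - mu) + eps[t] gives h[t] - mu = sum over k of phi**k * eps[t-k], which is a convolution of the noise with the kernel phi ** arange(n):

```python
import numpy as np


def ar1_loop(eps, mu, phi):
    """AR(1) path built step by step: h[t] = mu + phi * (h[t-1] - mu) + eps[t]."""
    h = np.empty_like(eps)
    h[0] = mu + eps[0]
    for t in range(1, len(eps)):
        h[t] = mu + phi * (h[t - 1] - mu) + eps[t]
    return h


def ar1_conv(eps, mu, phi):
    """Same path in one shot: convolve the noise with the kernel phi ** k."""
    n = len(eps)
    kernel = phi ** np.arange(n)
    # 'full' convolution has length 2n - 1; the first n entries are the causal part.
    return mu + np.convolve(eps, kernel)[:n]


rng = np.random.default_rng(0)
eps = rng.normal(scale=0.1, size=50)
h_loop = ar1_loop(eps, mu=-1.0, phi=0.95)
h_conv = ar1_conv(eps, mu=-1.0, phi=0.95)
print(np.max(np.abs(h_loop - h_conv)))  # difference is at floating-point level
```

The two paths agree up to rounding, so under these assumptions the kernel for the recursion is phi ** arange(n).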

@eb8680_2, please find the PyMC3 piece below, if it helps. It actually uses predefined AR distribution. I tried to implement it in pyro by inheriting dist.TorchDistribution and similarly defining log_prob. Haven’t really improved much. So probably the first thing to try now is to understand and apply an idea with vectorization using plate.

# load dependencies
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pymc3 as pm

from pymc3.distributions.timeseries import AR as ar
from scipy import optimize

from numpyro.examples.datasets import SP500, load_dataset

sns.set_context('talk')

# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
T = 1000
returns = returns[:T]
dates = dates[:T]

def create_model(returns):
    with pm.Model() as model_pmc3:
        
        phi = pm.Beta("phi", 20., 1.5)
        phi = 2*phi - 1
        sigma =  pm.InverseGamma('sigma', 2.5, 0.025)
        mu =  pm.Normal("mu", 0., 10.)

        h = ar('h', rho=[mu*(1-phi), phi], sigma=sigma ** 2, shape=len(returns))

        y = pm.Normal('y', 0., pm.math.exp(h/2), observed=returns)
        
        return model_pmc3
    
stochastic_vol_model = create_model(returns)

# run NFVI
nf = 'scale*10-loc*10'

with stochastic_vol_model:
    inference = pm.NFVI(nf, jitter=0.001)

n_smpl = 10000
inference.fit(n_smpl, obj_optimizer=pm.adam(learning_rate=.001), obj_n_mc=20)

traceNF = inference.approx.sample(n_smpl)

with stochastic_vol_model:
    posterior_predictive_nf = pm.sample_posterior_predictive(traceNF)

Thanks!

Thanks, Fritz. Following your suggestions I redefined the model:

# define model
def model(returns):
    
    phi = pyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = pyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = pyro.sample("mu", dist.Normal(0, 10))
    
    T = len(returns)
    means_white_noise = torch.tensor(mu*(1-phi)).repeat(T)
    vars_white_noise = torch.tensor(sigma2 ** 0.5).repeat(T)
    
    with pyro.plate("data", len(returns)):
        h = pyro.sample('h', dist.Normal(means_white_noise, vars_white_noise))
        h = pyro.ops.tensor_utils.convolve(h, phi ** torch.arange(T))[:T]

        y = pyro.sample('y', dist.Normal(0., (h / 2.).exp()), obs=returns)

However:

  1. It doesn’t seem to have improved computation time much (on short toy runs it was even slower);
  2. now, at the very end, at the neutra.transform_sample() stage, I get this error (please note that torch.Size([5]) stands for len(returns), while torch.Size([10]) is the number of samples):
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-49-6ca0cf6bcf2d> in <module>
    129     args = parser.parse_args(args=[])
    130 
--> 131     outDict = main(args)

<ipython-input-49-6ca0cf6bcf2d> in main(args)
     98 
     99         print(zs.shape)
--> 100         samples = neutra.transform_sample(zs)
    101 
    102         outDict['nf_neutra_mcmc'] = mcmc

~/anaconda3/lib/python3.7/site-packages/pyro/infer/reparam/neutra.py in transform_sample(self, latent)
    102         x_unconstrained = self.transform(latent)
    103         transformed_samples = {}
--> 104         for site, value in self.guide._unpack_latent(x_unconstrained):
    105             transform = biject_to(site["fn"].support)
    106             x_constrained = transform(value)

~/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py in _unpack_latent(self, latent)
    609             event_dim = site["fn"].event_dim + len(unconstrained_shape) - len(constrained_shape)
    610             unconstrained_shape = broadcast_shape(unconstrained_shape,
--> 611                                                   batch_shape + (1,) * event_dim)
    612             unconstrained_value = latent[..., pos:pos + size].view(unconstrained_shape)
    613             yield site, unconstrained_value

~/anaconda3/lib/python3.7/site-packages/pyro/distributions/util.py in broadcast_shape(*shapes, **kwargs)
    140             elif reversed_shape[i] != size and (size != 1 or strict):
    141                 raise ValueError('shape mismatch: objects cannot be broadcast to a single shape: {}'.format(
--> 142                     ' vs '.join(map(str, shapes))))
    143     return tuple(reversed(reversed_shape))
    144 

ValueError: shape mismatch: objects cannot be broadcast to a single shape: torch.Size([5]) vs torch.Size([10])

Could you point out whether I am doing something incorrectly, or suggest any workarounds?

Thanks,
Arturs

It actually uses predefined AR distribution. I tried to implement it in pyro by inheriting dist.TorchDistribution and similarly defining log_prob.

Similar to GaussianRandomWalk, you can predefine AR. The log_prob computation should be much faster than the iterative version.
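
To illustrate why a predefined distribution helps, here is a minimal NumPy sketch (not the Pyro or NumPyro API; normal_logpdf, ar1_logp_loop and ar1_logp_vec are hypothetical helper names) of the joint log-density of the AR(1) latent path from the model at the top of the thread. The vectorized version evaluates all T - 1 transition terms with one shifted-array expression instead of a Python loop:

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    """Elementwise log-density of Normal(loc, scale)."""
    z = (x - loc) / scale
    return -0.5 * z ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def ar1_logp_loop(h, mu, phi, sigma):
    """One log-density term per time step, as in the for-loop model."""
    total = normal_logpdf(h[0], mu, sigma / np.sqrt(1.0 - phi ** 2))
    for t in range(1, len(h)):
        total += normal_logpdf(h[t], mu + phi * (h[t - 1] - mu), sigma)
    return total


def ar1_logp_vec(h, mu, phi, sigma):
    """Same quantity with all transitions evaluated at once."""
    first = normal_logpdf(h[0], mu, sigma / np.sqrt(1.0 - phi ** 2))
    rest = normal_logpdf(h[1:], mu + phi * (h[:-1] - mu), sigma)
    return first + rest.sum()


rng = np.random.default_rng(1)
h = rng.normal(loc=-1.0, scale=0.3, size=200)
lp_loop = ar1_logp_loop(h, mu=-1.0, phi=0.9, sigma=0.2)
lp_vec = ar1_logp_vec(h, mu=-1.0, phi=0.9, sigma=0.2)
print(lp_loop, lp_vec)
```

In a tensor framework the second form builds a handful of large operations rather than T small ones, which is the regime where the per-operation overhead discussed earlier stops dominating.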

Haven’t really improved much.

Could you post your code? I can help optimize it.

Hi @fehiepsi, thanks for coming back. Let me reply in two blocks.

  1. The code uses the model definition I stated above. I also found that what improves the computation time (though not by much) is calling <VARIABLE>.repeat(T) directly on the variables, rather than torch.tensor(<VARIABLE>).repeat(T). The full code, which runs faster (but is still quite slow if you make the time series longer, say T=100) and produces the error I described above, is as follows:
import pandas as pd
import numpy as np
import argparse
import logging
import os

import torch
torch.set_default_tensor_type('torch.FloatTensor') # also tried on CUDA, didn't perform much faster
torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist
from pyro import optim, poutine
from pyro.distributions import constraints
from pyro.distributions.transforms import iterated, block_autoregressive, planar
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal, AutoNormalizingFlow, AutoIAFNormal
from pyro.infer.reparam import NeuTraReparam
from functools import partial

from numpyro.examples.datasets import SP500, load_dataset

logging.basicConfig(format='%(message)s', level=logging.INFO)

T = 200
# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:T]
dates = dates[:T]

# define model
def model(returns):
    
    phi = pyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = pyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = pyro.sample("mu", dist.Normal(0, 10))
    
    T = len(returns)
    means_white_noise = mu*(1-phi)
    vars_white_noise = sigma2 ** 0.5
    
    with pyro.plate("data", len(returns)):
        h = pyro.sample('h', dist.Normal(means_white_noise.repeat(T), vars_white_noise.repeat(T)))
        h = pyro.ops.tensor_utils.convolve(h, phi ** torch.arange(T))[:T]

        y = pyro.sample('y', dist.Normal(0., (h / 2.).exp()), obs=returns)
        

# define aux fn to fit the guide
def fit_guide(guide, args):
    pyro.clear_param_store()
    adam = optim.Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, adam, Trace_ELBO())
    for i in range(args.num_steps):
        loss = svi.step(args.data)
        if i % 500 == 0:
            logging.info("[{}]Elbo loss = {:.2f}".format(i, loss))

# define aux fn to run HMC
def run_hmc(args, model, print_summary=False):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, warmup_steps=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(args.data)
    if print_summary:
        mcmc.summary()
    return mcmc


def main(args):
    pyro.set_rng_seed(args.rng_seed)
    
    outDict = {}
    
    assert args.autoguide in ['iaf', 'nf']

    if args.autoguide == 'iaf':
        # Fit an IAF
        logging.info('\nFitting a IAF autoguide ...')
        guide = AutoIAFNormal(model)
        fit_guide(guide, args)


        # Draw samples using NeuTra HMC
        logging.info('\nDrawing samples using IAF autoguide + NeuTra HMC ...')
        neutra = NeuTraReparam(guide.requires_grad_(False))
        neutra_model = poutine.reparam(model, config=lambda _: neutra)
        mcmc = run_hmc(args, neutra_model)
        zs = mcmc.get_samples()['phi_shared_latent']

        samples = neutra.transform_sample(zs)

        outDict['iaf_neutra_mcmc'] = mcmc
        outDict['iaf_neutra_samples'] = samples

    # this else ignored in current run
    elif args.autoguide == 'nf':
        # If we want the Normalizing Flow
        # fit autoguide
        logging.info('\nFitting a BNAF autoguide ...')
        guide = AutoNormalizingFlow(model, partial(iterated, 2, planar))
        fit_guide(guide, args)

        # Draw samples using NeuTra HMC
        logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
        neutra = NeuTraReparam(guide.requires_grad_(False))
        neutra_model = poutine.reparam(model, config=lambda _: neutra)
        mcmc = run_hmc(args, neutra_model)
        zs = mcmc.get_samples()['phi_shared_latent']
        
        samples = neutra.transform_sample(zs)

        outDict['nf_neutra_mcmc'] = mcmc
        outDict['nf_neutra_samples'] = samples
    
    
    return outDict


if __name__ == '__main__':
    assert pyro.__version__.startswith('1.2.1')
    parser = argparse.ArgumentParser(description='Example illustrating NeuTra Reparametrizer')
    parser.add_argument('-n', '--num-steps', default=1000, type=int,
                        help='number of SVI steps')
    parser.add_argument('-lr', '--learning-rate', default=1e-4, type=float,
                        help='learning rate for the Adam optimizer')
    parser.add_argument('--rng-seed', default=123, type=int,
                        help='RNG seed')
    parser.add_argument('--num-warmup', default=1000, type=int,
                        help='number of warmup steps for NUTS')
    parser.add_argument('--num-samples', default=2000, type=int,
                        help='number of samples to be drawn from NUTS')
    parser.add_argument('--data', default=torch.Tensor(returns), type=float,
                        help='Time-series of returns')
    parser.add_argument('--num-flows', default=4, type=int,
                        help='number of flows in the BNAF autoguide')
    parser.add_argument('--autoguide', default='nf', type=str,
                        help='Autoguide spec')

    args = parser.parse_args(args=[])
    
    outDict = main(args)
  2. Another little update from my side: it looks like numpyro + numpyro.plate + a jax.lax convolution indeed improves performance significantly, and everything works fine for long time series (1,000 obs) with plain HMC. However, when running guide + NeuTra HMC (i.e. following the NeuTra example with normalizing flows suggested by @fehiepsi above), I get a memory error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-22793ef5a047> in <module>
    109     numpyro.set_platform(args.device)
    110 
--> 111     main(args)

<ipython-input-1-22793ef5a047> in main(args)
     76     print("Finish training guide. Extract samples...")
     77     guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
---> 78                                            sample_shape=(args.num_samples,))['phi'].copy()
     79 
     80     transform = guide.get_transform(params)

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in sample_posterior(self, rng_key, params, sample_shape)
    270         latent_sample = handlers.substitute(handlers.seed(self._sample_latent, rng_key), params)(
    271             self.base_dist, sample_shape=sample_shape)
--> 272         return self._unpack_and_constrain(latent_sample, params)
    273 
    274     def median(self, params):

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in _unpack_and_constrain(self, latent_sample, params)
    230                 return transform_fn(self._inv_transforms, unpacked_samples)
    231 
--> 232         unpacked_samples = vmap(unpack_single_latent)(latent_sample)
    233         unpacked_samples = tree_map(lambda x: np.reshape(x, sample_shape + np.shape(x)[1:]),
    234                                     unpacked_samples)

~/anaconda3/lib/python3.7/site-packages/jax/api.py in batched_fun(*args)
    692     _check_axis_sizes(in_tree, args_flat, in_axes_flat)
    693     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 694                               lambda: _flatten_axes(out_tree(), out_axes))
    695     return tree_unflatten(out_tree(), out_flat)
    696 

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     38 def batch(fun, in_vals, in_dims, out_dim_dests):
     39   size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
---> 40   out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
     41   return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
     42 

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in batch_fun(fun, in_vals, in_dims)
     44   with new_master(BatchTrace) as master:
     45     fun, out_dims = batch_subtrace(fun, master, in_dims)
---> 46     out_vals = fun.call_wrapped(*in_vals)
     47     del master
     48   return out_vals, out_dims()

~/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    150     gen = None
    151 
--> 152     ans = self.f(*args, **dict(self.params, **kwargs))
    153     del args
    154     while stack:

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in unpack_single_latent(latent)
    226                 model = handlers.substitute(self.model, params)
    227                 return constrain_fn(model, self._inv_transforms, model_args,
--> 228                                     model_kwargs, unpacked_samples)
    229             else:
    230                 return transform_fn(self._inv_transforms, unpacked_samples)

~/anaconda3/lib/python3.7/site-packages/numpyro/infer/util.py in constrain_fn(model, transforms, model_args, model_kwargs, params, return_deterministic)
    134     params_constrained = transform_fn(transforms, params)
    135     substituted_model = substitute(model, base_param_map=params_constrained)
--> 136     model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
    137     return {k: v['value'] for k, v in model_trace.items() if (k in params) or
    138             (return_deterministic and v['type'] == 'deterministic')}

~/anaconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    147         :return: `OrderedDict` containing the execution trace.
    148         """
--> 149         self(*args, **kwargs)
    150         return self.trace
    151 

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

<ipython-input-1-22793ef5a047> in stoch_vol_model()
     52         kernel = np.flip(np.reshape(kernel, (1, 1, len(kernel))), 2) #flip rhs
     53 
---> 54         h = lax.conv_general_dilated(h, kernel, [1],[(len(h)-1,len(h)-1)])[:N]
     55 
     56     return numpyro.sample('y', dist.Normal(0., (h / 2.)), obs=returns)

~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision)
    519       feature_group_count=feature_group_count,
    520       lhs_shape=lhs.shape, rhs_shape=rhs.shape,
--> 521       precision=_canonicalize_precision(precision))
    522 
    523 def dot(lhs, rhs, precision=None):

~/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **kwargs)
    157 
    158     tracers = map(top_trace.full_raise, args)
--> 159     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    160     if self.multiple_results:
    161       return map(full_lower, out_tracer)

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in process_primitive(self, primitive, tracers, params)
    112       # TODO(mattjj,phawkins): if no rule implemented, could vmap-via-map here
    113       batched_primitive = get_primitive_batcher(primitive)
--> 114       val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    115       if primitive.multiple_results:
    116         return map(partial(BatchTracer, self), val_out, dim_out)

~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in _conv_general_dilated_batch_rule(batched_args, batch_dims, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision, **unused_kwargs)
   2041       dimension_numbers,
   2042       feature_group_count=lhs.shape[lhs_bdim] * feature_group_count,
-> 2043       precision=precision)
   2044     out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out)
   2045     return out, out_spec[1]

~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision)
    519       feature_group_count=feature_group_count,
    520       lhs_shape=lhs.shape, rhs_shape=rhs.shape,
--> 521       precision=_canonicalize_precision(precision))
    522 
    523 def dot(lhs, rhs, precision=None):

~/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **kwargs)
    154     top_trace = find_top_trace(args)
    155     if top_trace is None:
--> 156       return self.impl(*args, **kwargs)
    157 
    158     tracers = map(top_trace.full_raise, args)

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    159   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
    160   compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
--> 161   return compiled_fun(*args)
    162 
    163 @cache()

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, backend, tuple_args, result_handler, *args)
    236   if tuple_args:
    237     input_bufs = [make_tuple(input_bufs, device, backend)]
--> 238   out_buf = compiled.Execute(input_bufs)
    239   if FLAGS.jax_debug_nans:
    240     check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)

RuntimeError: Resource exhausted: Failed to allocate request for 59.62GiB (64016000000B) on device ordinal 0

The code for this numpyro part is here:

import argparse
from functools import partial
import os

from jax import lax, random, vmap
import jax.numpy as np
from jax.tree_util import tree_map

import numpyro
from numpyro import optim
from numpyro.contrib.autoguide import AutoContinuousELBO, AutoBNAFNormal
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI
from numpyro.infer.util import initialize_model, transformed_potential_energy

from numpyro.examples.datasets import SP500, load_dataset

T = 1000
# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:T]
dates = dates[:T]


# define model
def stoch_vol_model():
    phi = numpyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = numpyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = numpyro.sample("mu", dist.Normal(0, 10))

    N = len(returns)
    means_white_noise = mu * (1 - phi)
    vars_white_noise = sigma2 ** 0.5

    with numpyro.plate("data", len(returns)):
        h = numpyro.sample('h', dist.Normal(means_white_noise.repeat(N), vars_white_noise.repeat(N)))

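        kernel = phi ** np.arange(N)

        # lax for convolution (very user unfriendly, by the way)
        # lax expects (batch, channel, length) and computes a cross-correlation, so flip the kernel
        lhs = np.reshape(h, (1, 1, N))
        rhs = np.flip(np.reshape(kernel, (1, 1, N)), 2)

        # left-pad with N - 1 zeros so that output t only depends on inputs 0..t
        h = lax.conv_general_dilated(lhs, rhs, [1], [(N - 1, 0)])[0, 0]

    # the observation scale is exp(h / 2), as in the Pyro model above
    return numpyro.sample('y', dist.Normal(0., np.exp(h / 2.)), obs=returns)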


def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(stoch_vol_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()

    # fit the guide
    guide = AutoBNAFNormal(stoch_vol_model, hidden_factors=[args.hidden_factor, args.hidden_factor])

    svi = SVI(stoch_vol_model, guide, optim.Adam(0.003), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))['phi'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), stoch_vol_model)
    transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()

    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))


if __name__ == "__main__":
    assert numpyro.__version__.startswith('0.2.4')
    parser = argparse.ArgumentParser(description="NeuTra HMC")
    parser.add_argument('-n', '--num-samples', nargs='?', default=4000, type=int)
    parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
    parser.add_argument('--hidden-factor', nargs='?', default=8, type=int)
    parser.add_argument('--num-iters', nargs='?', default=1000, type=int)
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args([])

    numpyro.set_platform(args.device)

    main(args)

Sorry for the long one this time. Please let me know your thoughts.

Thanks!

If you use NumPyro, then there are two ways to make it fast:

  • Use the following predefined AR (please double-check the implementation; I just quickly sketched it following the GaussianRandomWalk pattern. GRW can be seen as a special case of AR where coef=1)
from jax import lax, random
import jax.numpy as np

import numpyro
from numpyro.distributions import Distribution, Normal, constraints
from numpyro.distributions.util import validate_sample


class AR(Distribution):
    # keys must match the attribute names set in __init__
    arg_constraints = {'init': constraints.real_vector,
                       'coef': constraints.real_vector,
                       'scale': constraints.positive,
                       'num_steps': constraints.positive_integer}
    support = constraints.real_vector
    reparametrized_params = ['scale']

    def __init__(self, init, coef, scale=1., num_steps=1, validate_args=None):
        assert np.shape(num_steps) == ()
        assert init.shape[-1] == coef.shape[-1]
        self.init = init
        self.coef = coef
        self.scale = scale
        self.num_steps = num_steps
        batch_shape = lax.broadcast_shapes(np.shape(init)[:-1],
                                           np.shape(coef)[:-1],
                                           np.shape(scale))
        event_shape = (num_steps,)
        super(AR, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # just a fake sample method to get initial params for autoguide and mcmc
        eps = random.normal(key, sample_shape + self.batch_shape + self.event_shape)
        return np.expand_dims(self.scale, -1) * eps

    @validate_sample
    def log_prob(self, value):
        assert value.shape[-1] == self.num_steps
        batch_shape = lax.broadcast_shapes(self.batch_shape, np.shape(value)[:-1])
        init = np.broadcast_to(self.init, batch_shape + self.init.shape[-1:])
        value = np.broadcast_to(value, batch_shape + (self.num_steps,))
        x = np.concatenate([init, value], -1)
        # slice along the last (time) axis so that batched inputs also work
        x_reg = np.stack([x[..., i:self.num_steps + i] for i in range(self.coef.shape[-1])], -1)
        noise = value - (self.coef * x_reg).sum(-1)
        return Normal(0, np.expand_dims(self.scale, -1)).log_prob(noise).sum(-1)

where coef is rho in PyMC3 and init holds the initial values of the AR process. You can put a Normal prior on init in a separate numpyro.sample statement (PyMC3 uses a flat prior for init by default, so its AR is an improper distribution by default).

  • Use lax.scan with reparameterized AR (as in this topic).
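For the first option, the lag stacking inside log_prob is the part most worth double-checking. The NumPy sketch below mirrors that stacking for the unbatched case and compares it against a naive loop; the helper names ar_noise_stacked and ar_noise_loop are made up for this illustration and are not part of NumPyro.

```python
import numpy as np


def ar_noise_stacked(init, coef, value):
    """Residuals via stacked lags, mirroring AR.log_prob (unbatched case)."""
    num_steps = value.shape[-1]
    x = np.concatenate([init, value], -1)
    # column i holds x[t + i] for t = 0 .. num_steps - 1
    x_reg = np.stack([x[..., i:num_steps + i] for i in range(coef.shape[-1])], -1)
    return value - (coef * x_reg).sum(-1)


def ar_noise_loop(init, coef, value):
    """Same residuals written as an explicit loop over time steps."""
    p = coef.shape[-1]
    x = np.concatenate([init, value], -1)
    noise = np.empty_like(value)
    for t in range(value.shape[-1]):
        # value[t] is x[t + p]; its predecessors are x[t], ..., x[t + p - 1]
        noise[t] = value[t] - sum(coef[i] * x[t + i] for i in range(p))
    return noise


rng = np.random.default_rng(0)
init = rng.normal(size=2)
coef = np.array([0.1, 0.8])
value = rng.normal(size=5)
assert np.allclose(ar_noise_stacked(init, coef, value),
                   ar_noise_loop(init, coef, value))
```

One thing this makes visible: with this stacking, coef[-1] multiplies the most recent value (lag 1) and coef[0] the oldest one, and there is no constant term.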

Note that nf = 'scale*10-loc*10' is equivalent to numpyro.contrib.autoguide.AutoDiagonalNormal, so there is no need to use AutoBNAFNormal.

Edit: I just tested it. MCMC is pretty fast: >100it/s. SVI finishes in seconds.

def stoch_vol_model():
    phi = numpyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = numpyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    h_init = numpyro.sample("h_init", dist.Normal(0, 1), sample_shape=(2,))
    N = len(returns)
    h_ar = numpyro.sample("h_ar", AR(h_init, np.stack([phi, mu * (1 - phi)]), sigma2, N - 2))
    h = np.concatenate([h_init, h_ar], -1)
    return numpyro.sample('y', dist.Normal(0., np.exp(h / 2.)), obs=returns)

Btw, I think it is better to follow the stable example. Using a latent AR distribution like this is not scalable, and I think that with such a high number of latent variables the inference would be unstable! But using AR as an observation site would be fine.

Thanks, @fehiepsi, will check this out.

When you say

AR as an observation site would be fine (but not necessarily for latent variables, though)

Do you mean that with plain MCMC / HMC it should work fine, but not necessarily with Normalizing Flows / NeuTra ?

Your model has h as a “latent” variable with AR distribution. By “observed” variable, I meant a site with obs=... keyword. In PyMC3, it is observed=... keyword (see PyMC3 AR example).

Hey, thanks. Sorry for the silly question, but as “observed” we are passing the observed returns, i.e. the data points, right? It has nothing to do with the AR we defined above? Or am I missing something?

Yes, I just meant that using AR for an observed variable would be good, while using AR for a latent variable would be bad. Please ignore that sentence if you want to use AR as a latent variable. You can use the AR class which I defined above; it is equivalent to the PyMC3 one.

Hi @fehiepsi,

Thanks a lot for your responses. It worked in the end, in particular with numpyro.contrib.autoguide.AutoDiagonalNormal.
I just needed to update the AR definition a little to allow for a constant term in the AR process.

However, when I try AutoBNAFNormal (or AutoIAFNormal), I keep running into memory issues (as I stated here). Do you happen to know of any workarounds for this?

Thanks,
Arthur

@ameshkovskiy By default, BNAF has hidden_factors=[8, 8], which is quite large here: with many latent variables in your model (say d=1000), the weight of the middle layer has shape (8 x 1000) x (8 x 1000), which is very large. I am not sure whether BNAF, or normalizing flows in general, will work in that scenario (in the paper, the authors only report experiments for d <= 64).
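To put rough numbers on the middle-layer size mentioned above, here is a back-of-the-envelope sketch. It assumes the weight is stored as a dense float32 matrix and ignores gradients and optimizer state, which add further copies; the function name is hypothetical.

```python
def bnaf_middle_layer_params(latent_dim, hidden_factor):
    """Entries in a (hidden_factor * d) x (hidden_factor * d) weight matrix."""
    width = hidden_factor * latent_dim
    return width * width


BYTES_PER_FLOAT32 = 4

for d in (64, 1000):
    n_params = bnaf_middle_layer_params(d, hidden_factor=8)
    megabytes = n_params * BYTES_PER_FLOAT32 / 1e6
    print(f"d={d}: {n_params:,} weights, about {megabytes:.0f} MB in float32")
```

Under these assumptions, d=1000 gives 64 million weights (about 256 MB) for that single matrix, versus roughly 1 MB at d=64.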

Thanks!

Hi there,

To add to the conversation on speed, I am running the following inference problem in pyro==1.7.0 and pymc3==3.11.4

data = torch.distributions.Normal(0., 1.).sample((200,))

# pyro
def model(mu_0, sigma_0, sigma, data):
    mu = pyro.sample("mu", dist.Normal(mu_0, sigma_0))
    with pyro.plate("data"):
        y = pyro.sample("y", dist.Normal(mu, sigma), obs=data)

mcmc = MCMC(
    NUTS(partial(model, 4., 2., 1.)),
    warmup_steps=500,
    num_samples=10_000,
    num_chains=1,
)
mcmc.run(data)

# pymc3
with pm.Model() as model:
    mu = pm.Normal("mu", mu=4., sigma=2.)
    x = pm.Normal("observed", mu=mu, sigma=1., observed=data.numpy())
    trace_pm = pm.sample(10000, tune=500, chains=1)

PyMC3 terminates in 3 secs while pyro takes 13 secs. The difference is quite striking - is this due to issues with implementation?
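As a side note for anyone comparing the two samplers on this toy model: with a Normal prior and a Normal likelihood of known scale, the posterior of mu is available in closed form, so both outputs can be checked against it. A minimal sketch in plain Python, with a hypothetical helper name and its own simulated data rather than the tensor used above:

```python
import math
import random


def normal_posterior(mu_0, sigma_0, sigma, data):
    """Posterior mean and sd of mu for a Normal prior and known-sigma Normal likelihood."""
    n = len(data)
    precision = 1.0 / sigma_0 ** 2 + n / sigma ** 2
    mean = (mu_0 / sigma_0 ** 2 + sum(data) / sigma ** 2) / precision
    return mean, math.sqrt(1.0 / precision)


random.seed(0)
data = [random.gauss(0.0, 1.0) for _ in range(200)]  # stand-in for the 200 draws
post_mean, post_sd = normal_posterior(4.0, 2.0, 1.0, data)
print(post_mean, post_sd)
```

The sample mean and standard deviation of the mu draws from either library should land close to these two values.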

pyro runs on pytorch. to my understanding pymc3 has a theano backend. why would you expect two different backends to yield the same runtime?

in particular pytorch isn’t well-optimized for the small model/small tensor regime. jax (and therefore numpyro) is much more performant in that regime.

@martinjankowiak thanks for your answer. If I understand correctly, the difference is due to the backend framework.

Is there a case in which it is possible to speed up pytorch? From your answer, it sounds like pytorch can be a reasonable choice in the big-model/big-tensor regime.

(For context, in my actual application I will deal with the latter case)

where appropriate you can use JitTrace_ELBO and its analogs. depending on the use case this should give you a speed-up. but there are limits to the degree to which you can speed-up pytorch code, especially in the small tensor regime.

the main use case of pytorch is neural networks, where large tensor ops are the norm. as a consequence the pytorch devs have never worried too much about possible overhead slowdown in pytorch programs that execute large numbers of small tensor ops
