Pyro performs dramatically slower than PyMC3 with Normalizing Flows on stochastic-volatility model inference

Hi @fehiepsi, thanks for coming back. Let me reply in two blocks.

  1. The code uses the model definition I stated above. I also found that what improves the computation time (though not by much) is to call the .repeat(T) method directly on the variables, rather than calling torch.tensor(<VARIABLE>).repeat(T). The full code, which runs faster (but is still quite slow if you make the time-series longer, say T=100), and which produces the error I specified above, is as follows:
import pandas as pd
import numpy as np
import argparse
import logging
import os

import torch
torch.set_default_tensor_type('torch.FloatTensor') # also tried on CUDA, didn't perform much faster
# BUG FIX: the original evaluated the torch.device(...) ternary and discarded
# the result, which has no effect at all; bind it so it can actually be used.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist
from pyro import optim, poutine
from pyro.distributions import constraints
from pyro.distributions.transforms import iterated, block_autoregressive, planar
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal, AutoNormalizingFlow, AutoIAFNormal
from pyro.infer.reparam import NeuTraReparam
from functools import partial

from numpyro.examples.datasets import SP500, load_dataset

logging.basicConfig(format='%(message)s', level=logging.INFO)

T = 200  # length of the time-series slice used for inference
# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:T]
# BUG FIX: the original did `dates = returns[:T]`, clobbering the dates with
# a copy of the returns; slice the dates themselves instead.
dates = dates[:T]

# define model
def model(returns):
    """Stochastic-volatility model.

    returns: 1-D tensor of observed returns; its length fixes the
    time-series length T used below.
    """
    phi = pyro.sample("phi", dist.Beta(20, 1.5))
    # Map the Beta draw from (0, 1) onto (-1, 1) so phi can act as an
    # AR(1) persistence parameter.
    phi = 2 * phi - 1
    sigma2 = pyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = pyro.sample("mu", dist.Normal(0, 10))
    
    # NOTE(review): this local T shadows the module-level T = 200.
    T = len(returns)
    means_white_noise = mu*(1-phi)
    vars_white_noise = sigma2 ** 0.5
    
    with pyro.plate("data", len(returns)):
        # White-noise innovations of the latent log-volatility; the scalar
        # mean/scale are tiled to length T (reported above as faster than
        # torch.tensor(<VARIABLE>).repeat(T)).
        h = pyro.sample('h', dist.Normal(means_white_noise.repeat(T), vars_white_noise.repeat(T)))
        # Full convolution with the kernel (1, phi, phi^2, ...), truncated to
        # the first T terms — presumably an MA(inf) expansion of the AR(1)
        # recursion for h; TODO confirm against the intended model.
        h = pyro.ops.tensor_utils.convolve(h, phi ** torch.arange(T))[:T]

        # Observed returns with volatility exp(h / 2).
        y = pyro.sample('y', dist.Normal(0., (h / 2.).exp()), obs=returns)
        

# define aux fn to fit the guide
def fit_guide(guide, args):
    """Fit *guide* against the global model with SVI.

    Runs args.num_steps Adam steps at args.learning_rate and logs the
    ELBO loss every 500 steps.
    """
    pyro.clear_param_store()
    optimizer = optim.Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    step = 0
    while step < args.num_steps:
        loss = svi.step(args.data)
        if step % 500 == 0:
            logging.info("[{}]Elbo loss = {:.2f}".format(step, loss))
        step += 1

# define aux fn to run HMC
def run_hmc(args, model, print_summary=False):
    """Draw NUTS/HMC samples from *model* and return the MCMC object.

    Warmup/sample counts come from args; optionally prints a summary.
    """
    kernel = NUTS(model)
    sampler = MCMC(kernel, warmup_steps=args.num_warmup, num_samples=args.num_samples)
    sampler.run(args.data)
    if print_summary:
        sampler.summary()
    return sampler


def _neutra_sample(guide, args):
    """Freeze *guide*, reparameterize the model with NeuTra, run HMC, and
    return (mcmc, samples transformed back to model space)."""
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['phi_shared_latent']
    samples = neutra.transform_sample(zs)
    return mcmc, samples


def main(args):
    """Fit the chosen autoguide ('iaf' or 'nf'), then draw NeuTra-HMC samples.

    Returns a dict holding the MCMC object and the transformed samples,
    keyed '<autoguide>_neutra_mcmc' / '<autoguide>_neutra_samples'.
    """
    pyro.set_rng_seed(args.rng_seed)

    outDict = {}

    assert args.autoguide in ['iaf', 'nf']

    # The two branches differ only in the guide class and output keys; the
    # shared NeuTra/HMC/transform sequence lives in _neutra_sample.
    if args.autoguide == 'iaf':
        # Fit an IAF
        logging.info('\nFitting a IAF autoguide ...')
        guide = AutoIAFNormal(model)
        fit_guide(guide, args)

        # Draw samples using NeuTra HMC
        logging.info('\nDrawing samples using IAF autoguide + NeuTra HMC ...')
        mcmc, samples = _neutra_sample(guide, args)

        outDict['iaf_neutra_mcmc'] = mcmc
        outDict['iaf_neutra_samples'] = samples

    elif args.autoguide == 'nf':
        # Fit a planar normalizing-flow autoguide
        logging.info('\nFitting a BNAF autoguide ...')
        guide = AutoNormalizingFlow(model, partial(iterated, 2, planar))
        fit_guide(guide, args)

        # Draw samples using NeuTra HMC
        logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
        mcmc, samples = _neutra_sample(guide, args)

        outDict['nf_neutra_mcmc'] = mcmc
        outDict['nf_neutra_samples'] = samples

    return outDict


if __name__ == '__main__':
    assert pyro.__version__.startswith('1.2.1')
    parser = argparse.ArgumentParser(description='Example illustrating NeuTra Reparametrizer')
    parser.add_argument('-n', '--num-steps', default=1000, type=int,
                        help='number of SVI steps')
    parser.add_argument('-lr', '--learning-rate', default=1e-4, type=float,
                        help='learning rate for the Adam optimizer')
    parser.add_argument('--rng-seed', default=123, type=int,
                        help='RNG seed')
    parser.add_argument('--num-warmup', default=1000, type=int,
                        help='number of warmup steps for NUTS')
    parser.add_argument('--num-samples', default=2000, type=int,
                        help='number of samples to be drawn from NUTS')
    # BUG FIX: the original declared type=float for this tensor-valued
    # argument; any real CLI use of --data would call float("...") on the
    # string and crash. The type is dropped — the default tensor is what is
    # actually used, because parse_args is called with an empty list below.
    parser.add_argument('--data', default=torch.Tensor(returns),
                        help='Time-series of returns')
    parser.add_argument('--num-flows', default=4, type=int,
                        help='number of flows in the BNAF autoguide')
    parser.add_argument('--autoguide', default='nf', type=str,
                        help='Autoguide spec')

    # parse_args(args=[]) deliberately ignores the real command line
    # (notebook-friendly); every option above takes its default value.
    args = parser.parse_args(args=[])
    
    outDict = main(args)
  2. Another little update from my side: it looks like numpyro + pyro.plate + jax.lax convolution does indeed improve the performance significantly, and everything works fine for a long time-series (1,000 obs) with plain HMC. However, on the run with guide + NeuTra HMC (i.e. following the NeuTra example with normalizing flows, as suggested by @fehiepsi above), I get this memory error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-22793ef5a047> in <module>
    109     numpyro.set_platform(args.device)
    110 
--> 111     main(args)

<ipython-input-1-22793ef5a047> in main(args)
     76     print("Finish training guide. Extract samples...")
     77     guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
---> 78                                            sample_shape=(args.num_samples,))['phi'].copy()
     79 
     80     transform = guide.get_transform(params)

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in sample_posterior(self, rng_key, params, sample_shape)
    270         latent_sample = handlers.substitute(handlers.seed(self._sample_latent, rng_key), params)(
    271             self.base_dist, sample_shape=sample_shape)
--> 272         return self._unpack_and_constrain(latent_sample, params)
    273 
    274     def median(self, params):

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in _unpack_and_constrain(self, latent_sample, params)
    230                 return transform_fn(self._inv_transforms, unpacked_samples)
    231 
--> 232         unpacked_samples = vmap(unpack_single_latent)(latent_sample)
    233         unpacked_samples = tree_map(lambda x: np.reshape(x, sample_shape + np.shape(x)[1:]),
    234                                     unpacked_samples)

~/anaconda3/lib/python3.7/site-packages/jax/api.py in batched_fun(*args)
    692     _check_axis_sizes(in_tree, args_flat, in_axes_flat)
    693     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 694                               lambda: _flatten_axes(out_tree(), out_axes))
    695     return tree_unflatten(out_tree(), out_flat)
    696 

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     38 def batch(fun, in_vals, in_dims, out_dim_dests):
     39   size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
---> 40   out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
     41   return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
     42 

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in batch_fun(fun, in_vals, in_dims)
     44   with new_master(BatchTrace) as master:
     45     fun, out_dims = batch_subtrace(fun, master, in_dims)
---> 46     out_vals = fun.call_wrapped(*in_vals)
     47     del master
     48   return out_vals, out_dims()

~/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    150     gen = None
    151 
--> 152     ans = self.f(*args, **dict(self.params, **kwargs))
    153     del args
    154     while stack:

~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in unpack_single_latent(latent)
    226                 model = handlers.substitute(self.model, params)
    227                 return constrain_fn(model, self._inv_transforms, model_args,
--> 228                                     model_kwargs, unpacked_samples)
    229             else:
    230                 return transform_fn(self._inv_transforms, unpacked_samples)

~/anaconda3/lib/python3.7/site-packages/numpyro/infer/util.py in constrain_fn(model, transforms, model_args, model_kwargs, params, return_deterministic)
    134     params_constrained = transform_fn(transforms, params)
    135     substituted_model = substitute(model, base_param_map=params_constrained)
--> 136     model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
    137     return {k: v['value'] for k, v in model_trace.items() if (k in params) or
    138             (return_deterministic and v['type'] == 'deterministic')}

~/anaconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    147         :return: `OrderedDict` containing the execution trace.
    148         """
--> 149         self(*args, **kwargs)
    150         return self.trace
    151 

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

<ipython-input-1-22793ef5a047> in stoch_vol_model()
     52         kernel = np.flip(np.reshape(kernel, (1, 1, len(kernel))), 2) #flip rhs
     53 
---> 54         h = lax.conv_general_dilated(h, kernel, [1],[(len(h)-1,len(h)-1)])[:N]
     55 
     56     return numpyro.sample('y', dist.Normal(0., (h / 2.)), obs=returns)

~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision)
    519       feature_group_count=feature_group_count,
    520       lhs_shape=lhs.shape, rhs_shape=rhs.shape,
--> 521       precision=_canonicalize_precision(precision))
    522 
    523 def dot(lhs, rhs, precision=None):

~/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **kwargs)
    157 
    158     tracers = map(top_trace.full_raise, args)
--> 159     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    160     if self.multiple_results:
    161       return map(full_lower, out_tracer)

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in process_primitive(self, primitive, tracers, params)
    112       # TODO(mattjj,phawkins): if no rule implemented, could vmap-via-map here
    113       batched_primitive = get_primitive_batcher(primitive)
--> 114       val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    115       if primitive.multiple_results:
    116         return map(partial(BatchTracer, self), val_out, dim_out)

~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in _conv_general_dilated_batch_rule(batched_args, batch_dims, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision, **unused_kwargs)
   2041       dimension_numbers,
   2042       feature_group_count=lhs.shape[lhs_bdim] * feature_group_count,
-> 2043       precision=precision)
   2044     out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out)
   2045     return out, out_spec[1]

~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision)
    519       feature_group_count=feature_group_count,
    520       lhs_shape=lhs.shape, rhs_shape=rhs.shape,
--> 521       precision=_canonicalize_precision(precision))
    522 
    523 def dot(lhs, rhs, precision=None):

~/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **kwargs)
    154     top_trace = find_top_trace(args)
    155     if top_trace is None:
--> 156       return self.impl(*args, **kwargs)
    157 
    158     tracers = map(top_trace.full_raise, args)

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    159   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
    160   compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
--> 161   return compiled_fun(*args)
    162 
    163 @cache()

~/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, backend, tuple_args, result_handler, *args)
    236   if tuple_args:
    237     input_bufs = [make_tuple(input_bufs, device, backend)]
--> 238   out_buf = compiled.Execute(input_bufs)
    239   if FLAGS.jax_debug_nans:
    240     check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)

RuntimeError: Resource exhausted: Failed to allocate request for 59.62GiB (64016000000B) on device ordinal 0

The code for this numpyro part is here:

import argparse
from functools import partial
import os

from jax import lax, random, vmap
import jax.numpy as np
from jax.tree_util import tree_map

import numpyro
from numpyro import optim
from numpyro.contrib.autoguide import AutoContinuousELBO, AutoBNAFNormal
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI
from numpyro.infer.util import initialize_model, transformed_potential_energy

from numpyro.examples.datasets import SP500, load_dataset

T = 1000  # length of the time-series slice used for inference
# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:T]
# BUG FIX: the original did `dates = returns[:T]`, clobbering the dates with
# a copy of the returns; slice the dates themselves instead.
dates = dates[:T]


# define model
def stoch_vol_model():
    """Stochastic-volatility model over the module-level `returns` series.

    The latent log-volatility h is built as a truncated MA(inf) expansion of
    an AR(1): white noise convolved with the kernel (1, phi, phi^2, ...).
    """
    phi = numpyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1  # map Beta support (0, 1) onto (-1, 1)
    sigma2 = numpyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = numpyro.sample("mu", dist.Normal(0, 10))

    N = len(returns)
    means_white_noise = mu * (1 - phi)
    vars_white_noise = sigma2 ** 0.5

    with numpyro.plate("data", N):
        h = numpyro.sample('h', dist.Normal(means_white_noise.repeat(N), vars_white_noise.repeat(N)))

    kernel = phi ** np.arange(N)

    # lax expects NCW layout; flip the kernel so the op is a true convolution.
    lhs = np.reshape(h, (1, 1, N))
    rhs = np.flip(np.reshape(kernel, (1, 1, N)), 2)

    # BUG FIX: the original computed the padding as len(h)-1 AFTER reshaping
    # h to (1, 1, N), so len(h) == 1 and the padding was (0, 0); and `[:N]`
    # sliced the size-1 batch axis, a no-op. Pad with N-1 on both sides
    # (full convolution) and keep the first N spatial outputs, matching
    # pyro.ops.tensor_utils.convolve(h, kernel)[:N] in the PyTorch version.
    conv = lax.conv_general_dilated(lhs, rhs, [1], [(N - 1, N - 1)])
    h = conv[0, 0, :N]

    # BUG FIX: the .exp() on the scale was missing here, so h (which can be
    # negative) was passed directly as the Normal scale; the Pyro version
    # correctly uses (h / 2).exp().
    return numpyro.sample('y', dist.Normal(0., np.exp(h / 2.)), obs=returns)


def main(args):
    """Run plain NUTS, fit a BNAF autoguide with SVI, then run NeuTra HMC."""
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(stoch_vol_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()

    # fit the guide
    guide = AutoBNAFNormal(stoch_vol_model, hidden_factors=[args.hidden_factor, args.hidden_factor])

    svi = SVI(stoch_vol_model, guide, optim.Adam(0.003), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    # lax.scan jit-compiles and runs args.num_iters SVI updates; the zeros
    # array only fixes the iteration count, its values are unused.
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    # NOTE(review): this is the call that blows up in the attached traceback
    # (vmap over num_samples latent draws re-runs the model's convolution
    # per sample) — confirm memory use before raising num_samples.
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))['phi'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), stoch_vol_model)
    # Potential energy in the warped (guide) space, plus a map back to the
    # constrained model space for the drawn samples.
    transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    # HMC starts from the origin of the guide's latent space.
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()

    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))


if __name__ == "__main__":
    assert numpyro.__version__.startswith('0.2.4')

    # Build the CLI; parse_args([]) below means the defaults are always used.
    parser = argparse.ArgumentParser(description="NeuTra HMC")
    parser.add_argument('-n', '--num-samples', nargs='?', type=int, default=4000)
    parser.add_argument('--num-warmup', nargs='?', type=int, default=1000)
    parser.add_argument('--hidden-factor', nargs='?', type=int, default=8)
    parser.add_argument('--num-iters', nargs='?', type=int, default=1000)
    parser.add_argument('--device', type=str, default='cpu', help='use "cpu" or "gpu".')
    cli_args = parser.parse_args([])

    numpyro.set_platform(cli_args.device)

    main(cli_args)

Sorry for the long one this time. Please let me know your thoughts.

Thanks!

1 Like