Hi @fehiepsi, thanks for getting back to me. Let me reply in two blocks.
- The code uses the model definition I stated above. I also found that what improves the computation time (not by much, though) is to call the .repeat(T) method directly on the variables, i.e. <VARIABLE>.repeat(T) rather than torch.tensor(<VARIABLE>).repeat(T); a minimal timing sketch of the two variants is appended after the full script below. The full code, which runs faster (but is still quite slow if you make the time series longer, say T=100) and produces the error I specified above, is as follows:
import pandas as pd
import numpy as np
import argparse
import logging
import os
import torch
torch.set_default_tensor_type('torch.FloatTensor')  # also tried on CUDA, didn't perform much faster
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from pyro import optim, poutine
from pyro.distributions import constraints
from pyro.distributions.transforms import iterated, block_autoregressive, planar
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal, AutoNormalizingFlow, AutoIAFNormal
from pyro.infer.reparam import NeuTraReparam
from functools import partial
from numpyro.examples.datasets import SP500, load_dataset
logging.basicConfig(format='%(message)s', level=logging.INFO)
T = 200
# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:T]
dates = dates[:T]
# define model
def model(returns):
    phi = pyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = pyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = pyro.sample("mu", dist.Normal(0, 10))
    T = len(returns)
    means_white_noise = mu * (1 - phi)
    vars_white_noise = sigma2 ** 0.5
    with pyro.plate("data", len(returns)):
        h = pyro.sample('h', dist.Normal(means_white_noise.repeat(T), vars_white_noise.repeat(T)))
        h = pyro.ops.tensor_utils.convolve(h, phi ** torch.arange(T))[:T]
        y = pyro.sample('y', dist.Normal(0., (h / 2.).exp()), obs=returns)
# define aux fn to fit the guide
def fit_guide(guide, args):
    pyro.clear_param_store()
    adam = optim.Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, adam, Trace_ELBO())
    for i in range(args.num_steps):
        loss = svi.step(args.data)
        if i % 500 == 0:
            logging.info("[{}]Elbo loss = {:.2f}".format(i, loss))
# define aux fn to run HMC
def run_hmc(args, model, print_summary=False):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, warmup_steps=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(args.data)
    if print_summary:
        mcmc.summary()
    return mcmc
def main(args):
    pyro.set_rng_seed(args.rng_seed)
    outDict = {}
    assert args.autoguide in ['iaf', 'nf']
    if args.autoguide == 'iaf':
        # Fit an IAF
        logging.info('\nFitting an IAF autoguide ...')
        guide = AutoIAFNormal(model)
        fit_guide(guide, args)
        # Draw samples using NeuTra HMC
        logging.info('\nDrawing samples using IAF autoguide + NeuTra HMC ...')
        neutra = NeuTraReparam(guide.requires_grad_(False))
        neutra_model = poutine.reparam(model, config=lambda _: neutra)
        mcmc = run_hmc(args, neutra_model)
        zs = mcmc.get_samples()['phi_shared_latent']
        samples = neutra.transform_sample(zs)
        outDict['iaf_neutra_mcmc'] = mcmc
        outDict['iaf_neutra_samples'] = samples
    # this else ignored in current run
    elif args.autoguide == 'nf':
        # If we want the Normalizing Flow
        # fit autoguide
        logging.info('\nFitting a BNAF autoguide ...')
        guide = AutoNormalizingFlow(model, partial(iterated, 2, planar))
        fit_guide(guide, args)
        # Draw samples using NeuTra HMC
        logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
        neutra = NeuTraReparam(guide.requires_grad_(False))
        neutra_model = poutine.reparam(model, config=lambda _: neutra)
        mcmc = run_hmc(args, neutra_model)
        zs = mcmc.get_samples()['phi_shared_latent']
        samples = neutra.transform_sample(zs)
        outDict['nf_neutra_mcmc'] = mcmc
        outDict['nf_neutra_samples'] = samples
    return outDict
if __name__ == '__main__':
    assert pyro.__version__.startswith('1.2.1')
    parser = argparse.ArgumentParser(description='Example illustrating NeuTra Reparametrizer')
    parser.add_argument('-n', '--num-steps', default=1000, type=int,
                        help='number of SVI steps')
    parser.add_argument('-lr', '--learning-rate', default=1e-4, type=float,
                        help='learning rate for the Adam optimizer')
    parser.add_argument('--rng-seed', default=123, type=int,
                        help='RNG seed')
    parser.add_argument('--num-warmup', default=1000, type=int,
                        help='number of warmup steps for NUTS')
    parser.add_argument('--num-samples', default=2000, type=int,
                        help='number of samples to be drawn from NUTS')
    parser.add_argument('--data', default=torch.Tensor(returns), type=float,
                        help='Time-series of returns')
    parser.add_argument('--num-flows', default=4, type=int,
                        help='number of flows in the BNAF autoguide')
    parser.add_argument('--autoguide', default='nf', type=str,
                        help='Autoguide spec')
    args = parser.parse_args(args=[])
    outDict = main(args)
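For reference, here is the minimal timing sketch of the two .repeat variants I mentioned above (the values and iteration count are just illustrative placeholders I picked; absolute timings will of course differ by machine). As far as I understand, torch.tensor(x) on an existing tensor makes a copy and detaches it from the autograd graph, which presumably explains the overhead:

import timeit
import torch

T = 100
mu = torch.tensor(0.5)
phi = torch.tensor(0.9)
means_white_noise = mu * (1 - phi)  # 0-dim tensor, as in the model above

# variant 1: wrap in torch.tensor first (copies the data and detaches it)
t_wrapped = timeit.timeit(lambda: torch.tensor(means_white_noise).repeat(T), number=1000)

# variant 2: call .repeat directly on the existing tensor
t_direct = timeit.timeit(lambda: means_white_noise.repeat(T), number=1000)

print("torch.tensor(x).repeat(T): {:.4f}s  |  x.repeat(T): {:.4f}s".format(t_wrapped, t_direct))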
- Another little update from my side on this: it looks like numpyro + numpyro.plate + the jax.lax convolution indeed improves performance significantly, and everything works fine for a long time series (1,000 observations) with plain HMC. However, on the run with the guide + NeuTra HMC (i.e. following the NeuTra example with normalizing flows, as suggested by @fehiepsi above), I get the memory error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-1-22793ef5a047> in <module>
109 numpyro.set_platform(args.device)
110
--> 111 main(args)
<ipython-input-1-22793ef5a047> in main(args)
76 print("Finish training guide. Extract samples...")
77 guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
---> 78 sample_shape=(args.num_samples,))['phi'].copy()
79
80 transform = guide.get_transform(params)
~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in sample_posterior(self, rng_key, params, sample_shape)
270 latent_sample = handlers.substitute(handlers.seed(self._sample_latent, rng_key), params)(
271 self.base_dist, sample_shape=sample_shape)
--> 272 return self._unpack_and_constrain(latent_sample, params)
273
274 def median(self, params):
~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in _unpack_and_constrain(self, latent_sample, params)
230 return transform_fn(self._inv_transforms, unpacked_samples)
231
--> 232 unpacked_samples = vmap(unpack_single_latent)(latent_sample)
233 unpacked_samples = tree_map(lambda x: np.reshape(x, sample_shape + np.shape(x)[1:]),
234 unpacked_samples)
~/anaconda3/lib/python3.7/site-packages/jax/api.py in batched_fun(*args)
692 _check_axis_sizes(in_tree, args_flat, in_axes_flat)
693 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 694 lambda: _flatten_axes(out_tree(), out_axes))
695 return tree_unflatten(out_tree(), out_flat)
696
~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
38 def batch(fun, in_vals, in_dims, out_dim_dests):
39 size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
---> 40 out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
41 return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
42
~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in batch_fun(fun, in_vals, in_dims)
44 with new_master(BatchTrace) as master:
45 fun, out_dims = batch_subtrace(fun, master, in_dims)
---> 46 out_vals = fun.call_wrapped(*in_vals)
47 del master
48 return out_vals, out_dims()
~/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
150 gen = None
151
--> 152 ans = self.f(*args, **dict(self.params, **kwargs))
153 del args
154 while stack:
~/anaconda3/lib/python3.7/site-packages/numpyro/contrib/autoguide/__init__.py in unpack_single_latent(latent)
226 model = handlers.substitute(self.model, params)
227 return constrain_fn(model, self._inv_transforms, model_args,
--> 228 model_kwargs, unpacked_samples)
229 else:
230 return transform_fn(self._inv_transforms, unpacked_samples)
~/anaconda3/lib/python3.7/site-packages/numpyro/infer/util.py in constrain_fn(model, transforms, model_args, model_kwargs, params, return_deterministic)
134 params_constrained = transform_fn(transforms, params)
135 substituted_model = substitute(model, base_param_map=params_constrained)
--> 136 model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
137 return {k: v['value'] for k, v in model_trace.items() if (k in params) or
138 (return_deterministic and v['type'] == 'deterministic')}
~/anaconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
147 :return: `OrderedDict` containing the execution trace.
148 """
--> 149 self(*args, **kwargs)
150 return self.trace
151
~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
61 def __call__(self, *args, **kwargs):
62 with self:
---> 63 return self.fn(*args, **kwargs)
64
65
~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
61 def __call__(self, *args, **kwargs):
62 with self:
---> 63 return self.fn(*args, **kwargs)
64
65
~/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
61 def __call__(self, *args, **kwargs):
62 with self:
---> 63 return self.fn(*args, **kwargs)
64
65
<ipython-input-1-22793ef5a047> in stoch_vol_model()
52 kernel = np.flip(np.reshape(kernel, (1, 1, len(kernel))), 2) #flip rhs
53
---> 54 h = lax.conv_general_dilated(h, kernel, [1],[(len(h)-1,len(h)-1)])[:N]
55
56 return numpyro.sample('y', dist.Normal(0., (h / 2.)), obs=returns)
~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision)
519 feature_group_count=feature_group_count,
520 lhs_shape=lhs.shape, rhs_shape=rhs.shape,
--> 521 precision=_canonicalize_precision(precision))
522
523 def dot(lhs, rhs, precision=None):
~/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **kwargs)
157
158 tracers = map(top_trace.full_raise, args)
--> 159 out_tracer = top_trace.process_primitive(self, tracers, kwargs)
160 if self.multiple_results:
161 return map(full_lower, out_tracer)
~/anaconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in process_primitive(self, primitive, tracers, params)
112 # TODO(mattjj,phawkins): if no rule implemented, could vmap-via-map here
113 batched_primitive = get_primitive_batcher(primitive)
--> 114 val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
115 if primitive.multiple_results:
116 return map(partial(BatchTracer, self), val_out, dim_out)
~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in _conv_general_dilated_batch_rule(batched_args, batch_dims, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision, **unused_kwargs)
2041 dimension_numbers,
2042 feature_group_count=lhs.shape[lhs_bdim] * feature_group_count,
-> 2043 precision=precision)
2044 out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out)
2045 return out, out_spec[1]
~/anaconda3/lib/python3.7/site-packages/jax/lax/lax.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision)
519 feature_group_count=feature_group_count,
520 lhs_shape=lhs.shape, rhs_shape=rhs.shape,
--> 521 precision=_canonicalize_precision(precision))
522
523 def dot(lhs, rhs, precision=None):
~/anaconda3/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **kwargs)
154 top_trace = find_top_trace(args)
155 if top_trace is None:
--> 156 return self.impl(*args, **kwargs)
157
158 tracers = map(top_trace.full_raise, args)
~/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
159 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
160 compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
--> 161 return compiled_fun(*args)
162
163 @cache()
~/anaconda3/lib/python3.7/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, backend, tuple_args, result_handler, *args)
236 if tuple_args:
237 input_bufs = [make_tuple(input_bufs, device, backend)]
--> 238 out_buf = compiled.Execute(input_bufs)
239 if FLAGS.jax_debug_nans:
240 check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)
RuntimeError: Resource exhausted: Failed to allocate request for 59.62GiB (64016000000B) on device ordinal 0
The code for this numpyro part is below; a small standalone check of the lax convolution is appended after the script:
import argparse
from functools import partial
import os
from jax import lax, random, vmap
import jax.numpy as np
from jax.tree_util import tree_map
import numpyro
from numpyro import optim
from numpyro.contrib.autoguide import AutoContinuousELBO, AutoBNAFNormal
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI
from numpyro.infer.util import initialize_model, transformed_potential_energy
from numpyro.examples.datasets import SP500, load_dataset
T = 1000
# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:T]
dates = dates[:T]
# define model
def stoch_vol_model():
    phi = numpyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = numpyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    N = len(returns)
    means_white_noise = mu * (1 - phi)
    vars_white_noise = sigma2 ** 0.5
    with numpyro.plate("data", len(returns)):
        h = numpyro.sample('h', dist.Normal(means_white_noise.repeat(N), vars_white_noise.repeat(N)))
        kernel = phi ** np.arange(N)
        # lax for convolution (very user unfriendly, by the way)
        h = np.reshape(h, (1, 1, len(h)))
        kernel = np.flip(np.reshape(kernel, (1, 1, len(kernel))), 2)  # flip rhs
        h = lax.conv_general_dilated(h, kernel, [1], [(len(h) - 1, len(h) - 1)])[:N]
        return numpyro.sample('y', dist.Normal(0., (h / 2.)), obs=returns)
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(stoch_vol_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    # fit the guide
    guide = AutoBNAFNormal(stoch_vol_model, hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(stoch_vol_model, guide, optim.Adam(0.003), AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))
    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))['phi'].copy()
    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), stoch_vol_model)
    transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))
    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))
if __name__ == "__main__":
assert numpyro.__version__.startswith('0.2.4')
parser = argparse.ArgumentParser(description="NeuTra HMC")
parser.add_argument('-n', '--num-samples', nargs='?', default=4000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
parser.add_argument('--hidden-factor', nargs='?', default=8, type=int)
parser.add_argument('--num-iters', nargs='?', default=1000, type=int)
parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
args = parser.parse_args([])
numpyro.set_platform(args.device)
main(args)
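And since the lax convolution call is quite easy to get wrong (hence the "user unfriendly" comment in stoch_vol_model), here is a small standalone sketch, separate from the script above, checking that this flipped-kernel conv_general_dilated construction reproduces the AR(1) recursion h[t] = phi * h[t-1] + w[t] (with the mu * (1 - phi) offset folded into the white noise, as in the model). Note that I index [0, 0, :N] here to keep the first N outputs along the spatial axis:

import jax.numpy as np
from jax import lax, random

N, phi = 10, 0.9
w = random.normal(random.PRNGKey(0), (N,))  # white noise

# reference: naive AR(1) recursion
h_ref = []
prev = 0.0
for t in range(N):
    prev = phi * prev + w[t]
    h_ref.append(prev)
h_ref = np.stack(h_ref)

# lax version: cross-correlate with the flipped geometric kernel and "full" padding
lhs = np.reshape(w, (1, 1, N))                                # (batch, channel, spatial)
rhs = np.flip(np.reshape(phi ** np.arange(N), (1, 1, N)), 2)  # flipped kernel phi**k
h_conv = lax.conv_general_dilated(lhs, rhs, (1,), [(N - 1, N - 1)])[0, 0, :N]

print(np.allclose(h_ref, h_conv, atol=1e-5))  # expect True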
Sorry for the long one this time. Please let me know your thoughts.
Thanks!