Hello. I’m experimenting with an MA(1) model for time series analysis, but I’m facing an error and I would like to ask for some advice.
I am running the following code:
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample
import numpyro.distributions as dist
from jax import random
from numpyro.contrib.control_flow import scan
import jax.numpy as jnp
def model(y):
n_series = y.shape[1]
with numpyro.plate("plate", n_series):
theta = numpyro.sample("theta", dist.Normal(0.0, 70.0))
def transition_fn(carry, t_idx):
epsilon_prev = carry
mean = theta*epsilon_prev
y_pred = numpyro.sample("y_pred", dist.Normal(mean, 1.0))
carry_new = y_pred - mean
return carry_new, y_pred
time_indices = jnp.arange(0, y.shape[0])
init_carry = jnp.empty(shape=(n_series), dtype=float)
with numpyro.handlers.condition(data={"y_pred": y}):
_, y_predictions = scan(
transition_fn,
init_carry,
time_indices,
)
rng_key = random.PRNGKey(12)
seeds = random.split(rng_key, 10)
n_series = 1
nuts_kernel = NUTS(model, init_strategy=init_to_sample)
train = random.normal(shape=(200,n_series), key=seeds[0])
mcmc = MCMC(nuts_kernel, num_samples=4000, num_warmup=2000, num_chains=1)
mcmc.run(seeds[1], y=train)
mcmc.print_summary()
This produces the following traceback:
Traceback (most recent call last):
File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 37, in <module>
mcmc.run(seeds[1], y=train)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 593, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 381, in _single_chain_mcmc
init_state = self.sampler.init(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 706, in init
init_params = self._init_state(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 652, in _init_state
init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/util.py", line 676, in initialize_model
substituted_model(*model_args, **model_kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 24, in model
_, y_predictions = scan(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
msg = apply_stack(initial_msg)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 305, in scan_wrapper
last_carry, (pytree_trace, ys) = lax.scan(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 278, in scan
return tree_unflatten(out_tree, out)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/tree_util.py", line 71, in tree_unflatten
return treedef.unflatten(leaves)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 153, in tree_unflatten
return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 99, in __call__
return super().__call__(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/continuous.py", line 1701, in __init__
super(Normal, self).__init__(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 177, in __init__
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Normal distribution got invalid loc parameter.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/contextlib.py", line 131, in __exit__
self.gen.throw(type, value, traceback)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 80, in validation_enabled
yield
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/util.py", line 676, in initialize_model
substituted_model(*model_args, **model_kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 24, in model
_, y_predictions = scan(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
msg = apply_stack(initial_msg)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 305, in scan_wrapper
last_carry, (pytree_trace, ys) = lax.scan(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 153, in tree_unflatten
return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 99, in __call__
return super().__call__(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/continuous.py", line 1701, in __init__
super(Normal, self).__init__(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 177, in __init__
raise ValueError(
ValueError: Normal distribution got invalid loc parameter.
The key line would be ValueError: Normal distribution got invalid loc parameter.
. From debugging it appears that the loc parameter diverges to infinity. This occurs before the chain becomes to warmup. Also I think the loc parameter in question must be the mean
, but I don’t see why this would diverge since it’s the product of two normally distributed variabes: theta
and y_pred - mean
, which both have finite scale parameters.
When looking at a single series this problem can be easily fixed by changing the random seed or using a more informative prior, since theta = numpyro.sample("theta", dist.Normal(0.0, 70.0))
gives a pretty wide range. But the more series I include the less reliable these solutions become. For example with n_series=50
and
with numpyro.plate("plate", n_series):
theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))
I cannot find a seed which allows the MCMC sampling to be successful. So even though for a single series the problem can be sidestepped, I think for many series I will need to solve it.
Does anybody have any advice on this issue? Would be glad to hear it.