Exploding loc parameter in MA(1) time series model

Hello. I’m experimenting with an MA(1) model for time series analysis, but I’m facing an error and I would like to ask for some advice.

I am running the following code:

``````import numpyro
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample
import numpyro.distributions as dist
from jax import random
from numpyro.contrib.control_flow import scan
import jax.numpy as jnp

def model(y):
    """MA(1) model ``y_t = theta * epsilon_{t-1} + epsilon_t`` with unit noise scale.

    One ``theta`` is drawn per series (one per column of ``y``). Because ``y_pred``
    is conditioned on the observed data, the innovations are reconstructed
    recursively as ``epsilon_t = y_t - theta * epsilon_{t-1}``. That recursion can
    grow geometrically when ``|theta| > 1``. This is why the ``Normal`` loc
    parameter can become invalid (non-finite) during initialisation for large
    ``theta``.

    Args:
        y: observed series indexed as ``y[t, series]`` (time along axis 0).
    """
    n_series = y.shape[1]
    with numpyro.plate("plate", n_series):
        theta = numpyro.sample("theta", dist.Normal(0.0, 70.0))

    def transition_fn(carry, t_idx):
        # carry is the previous innovation epsilon_{t-1}; t_idx is only the scan driver.
        epsilon_prev = carry
        mean = theta * epsilon_prev
        y_pred = numpyro.sample("y_pred", dist.Normal(mean, 1.0))
        # Under the condition handler y_pred is the observed y_t, so this residual
        # is not itself a draw from Normal(0, 1) and need not stay bounded.
        carry_new = y_pred - mean
        return carry_new, y_pred

    time_indices = jnp.arange(0, y.shape[0])
    # Start from epsilon_0 = 0 explicitly. jnp.empty also yields zeros in JAX,
    # but jnp.zeros states the intent instead of relying on that detail.
    init_carry = jnp.zeros(shape=(n_series,), dtype=float)
    with numpyro.handlers.condition(data={"y_pred": y}):
        scan(transition_fn, init_carry, time_indices)

# Reproduction: fit the MA(1) model to i.i.d. standard-normal data.
rng_key = random.PRNGKey(12)
seeds = random.split(rng_key, 10)
data_key, chain_key = seeds[0], seeds[1]

n_series = 1
train = random.normal(key=data_key, shape=(200, n_series))

nuts_kernel = NUTS(model, init_strategy=init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=2000, num_samples=4000, num_chains=1)
mcmc.run(chain_key, y=train)
mcmc.print_summary()
``````

This produces the following traceback:

``````Traceback (most recent call last):
File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 37, in <module>
mcmc.run(seeds[1], y=train)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 593, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 381, in _single_chain_mcmc
init_state = self.sampler.init(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 706, in init
init_params = self._init_state(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 652, in _init_state
init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/util.py", line 676, in initialize_model
substituted_model(*model_args, **model_kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 24, in model
_, y_predictions = scan(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
msg = apply_stack(initial_msg)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 305, in scan_wrapper
last_carry, (pytree_trace, ys) = lax.scan(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 278, in scan
return tree_unflatten(out_tree, out)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/tree_util.py", line 71, in tree_unflatten
return treedef.unflatten(leaves)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 153, in tree_unflatten
return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 99, in __call__
return super().__call__(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/continuous.py", line 1701, in __init__
super(Normal, self).__init__(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 177, in __init__
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Normal distribution got invalid loc parameter.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/contextlib.py", line 131, in __exit__
self.gen.throw(type, value, traceback)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 80, in validation_enabled
yield
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/util.py", line 676, in initialize_model
substituted_model(*model_args, **model_kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 24, in model
_, y_predictions = scan(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
msg = apply_stack(initial_msg)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 305, in scan_wrapper
last_carry, (pytree_trace, ys) = lax.scan(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 153, in tree_unflatten
return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 99, in __call__
return super().__call__(*args, **kwargs)
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/continuous.py", line 1701, in __init__
super(Normal, self).__init__(
File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 177, in __init__
raise ValueError(
ValueError: Normal distribution got invalid loc parameter.
``````

The key line is `ValueError: Normal distribution got invalid loc parameter.`. From debugging, it appears that the loc parameter diverges to infinity, and this happens before the chain even reaches warmup. I think the loc parameter in question must be `mean`, but I don’t see why it would diverge, since it is the product of two normally distributed variables, `theta` and `y_pred - mean`, which both have finite scale parameters.

When looking at a single series this problem can be easily fixed by changing the random seed or using a more informative prior, since `theta = numpyro.sample("theta", dist.Normal(0.0, 70.0))` gives a pretty wide range. But the more series I include the less reliable these solutions become. For example with `n_series=50` and

``````
with numpyro.plate("plate", n_series):
    theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))
``````

I cannot find a seed which allows the MCMC sampling to be successful. So even though for a single series the problem can be sidestepped, I think for many series I will need to solve it.

Does anybody have any advice on this issue? Would be glad to hear it.

We provide various initialization strategies which might be helpful in your case. For example,

• init_to_feasible
• init_to_value(values=…)

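Regarding the explosion: the model computes the residual recursively as `epsilon_t = y_t - theta * epsilon_{t-1}`. Assuming y = 0, you can solve this analytically as `epsilon_t = (-theta)**t * epsilon_0`. This shows that the magnitude of the mean `theta * epsilon_{t-1}` keeps increasing when |theta| > 1.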

1 Like

Thank you. This is very interesting. I’m learning a lot about how initialisation occurs.

Since I am sampling `y_pred` from a normal distribution about the predicted mean, I had assumed that `y_pred-mean` would also be normally distributed with a finite variance. But if I understand correctly, since I am supplying observed data `y_pred` is taken to be whatever I supply. So there is no guarantee that `y_pred - mean` will remain finite if `abs(theta)>1`.

It seems that any initialisation strategy that provides a sufficiently small value of `theta` will be successful, e.g. `nuts_kernel = NUTS(model, init_strategy=init_to_value(values={"theta": 0.0}))`.

However, so far I find that any dataset with `theta` greater than 1 cannot be fitted by this model. For example we could create the following synthetic dataset:

``````
true_theta = 2  # MA(1) coefficient used to generate the data
noise = random.normal(shape=(2000, n_series), key=seeds[0])
noise = noise.at[0].set(0.0)  # the first innovation is zero
train = noise[1:] + true_theta * noise[:-1]  # y_t = epsilon_t + true_theta * epsilon_{t-1}
``````

The true value of `theta` in this case is 2, but the model cannot fit it. Even though the model is correctly specified for this data, it is very unstable and does not recover the correct answer. If I run the following code:

``````import numpyro
from numpyro.infer import MCMC, NUTS, init_to_sample, init_to_feasible, init_to_value
import numpyro.distributions as dist
from jax import random
from numpyro.contrib.control_flow import scan
import jax.numpy as jnp

def model(y):
    """MA(1) model ``y_t = theta * epsilon_{t-1} + epsilon_t`` with unit noise scale.

    One ``theta`` is drawn per series (one per column of ``y``) from a standard
    normal prior. Because ``y_pred`` is conditioned on the observed data, the
    innovations are reconstructed recursively as
    ``epsilon_t = y_t - theta * epsilon_{t-1}``. That recursion can grow
    geometrically when ``|theta| > 1``, so starting the sampler at such a
    ``theta`` can produce an invalid (non-finite) ``Normal`` loc parameter.

    Args:
        y: observed series indexed as ``y[t, series]`` (time along axis 0).
    """
    n_series = y.shape[1]
    with numpyro.plate("plate", n_series):
        theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))

    def transition_fn(carry, t_idx):
        # carry is the previous innovation epsilon_{t-1}; t_idx is only the scan driver.
        epsilon_prev = carry
        mean = theta * epsilon_prev
        y_pred = numpyro.sample("y_pred", dist.Normal(mean, 1.0))
        # Under the condition handler y_pred is the observed y_t, so this residual
        # is not itself a draw from Normal(0, 1) and need not stay bounded.
        carry_new = y_pred - mean
        return carry_new, y_pred

    time_indices = jnp.arange(0, y.shape[0])
    # Start from epsilon_0 = 0 explicitly. jnp.empty also yields zeros in JAX,
    # but jnp.zeros states the intent instead of relying on that detail.
    init_carry = jnp.zeros(shape=(n_series,), dtype=float)
    with numpyro.handlers.condition(data={"y_pred": y}):
        scan(transition_fn, init_carry, time_indices)

rng_key = random.PRNGKey(12)
seeds = random.split(rng_key, 10)
data_key, chain_key = seeds[0], seeds[1]
n_series = 1

# Synthetic MA(1) data with theta = 2; the first innovation is forced to zero.
noise = random.normal(key=data_key, shape=(200, n_series))
noise = noise.at[0].set(0.0)
train = 2 * noise[:-1] + noise[1:]

# Start theta at 0.0 so the residual recursion stays finite during initialisation.
nuts_kernel = NUTS(model, init_strategy=init_to_value(values={"theta": 0.0}))
mcmc = MCMC(nuts_kernel, num_warmup=20000, num_samples=40000, num_chains=4)
mcmc.run(chain_key, y=train)
mcmc.print_summary()
``````

I get the following output:

``````                mean       std    median      5.0%     95.0%     n_eff     r_hat
theta      0.44      0.03      0.44      0.39      0.50  57943.36      1.00

Number of divergences: 0
``````

There is no indication from `n_eff` or `r_hat` that a problem occurred, but the result for `theta` is wrong. Also, if I use `nuts_kernel = NUTS(model, init_strategy=init_to_value(values={"theta": 2.0}))` (i.e. I initialise using the correct answer), then I get the `ValueError: Normal distribution got invalid loc parameter.` again.

Fortunately this example is quite contrived and unlikely to occur in realistic data. In fact if we include a scale parameter sigma for the `y_pred` distribution

``````
y_pred = numpyro.sample("y_pred", dist.Normal(mean, sigma))  # sigma replaces the fixed scale 1.0
``````

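then we can show that sigma = 1, theta = 2 (the supplied data) is equivalent to sigma = 2, theta = 0.5 (which is what the model returns if sigma is introduced). Both parameter sets give the same variance `sigma**2 * (1 + theta**2) = 5` and the same lag-1 autocovariance `theta * sigma**2 = 2`, so the data cannot distinguish them: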

``````             mean       std    median      5.0%     95.0%     n_eff     r_hat
sigma      1.95      0.09      1.95      1.79      2.09   1019.56      1.00
theta      0.46      0.06      0.46      0.37      0.56    786.29      1.00

Number of divergences: 0
``````

But this does make me wonder if there are other more realistic situations where we can write correct but very unstable models which return seemingly accurate results (large n_eff values, r_hat=1, small std, 0 divergences) that are actually wrong.

I guess this shows we need to be careful and understand our models and data well.

Anyway, thanks for the initialisation suggestions!

Hmm, this seems like a bug. Could you make a github issue for this? I will take a look later this week.

Thank you. I would be glad if you could look into it further. Here is the issue: https://github.com/pyro-ppl/numpyro/issues/1453