Exploding loc parameter in MA(1) time series model

Hello. I’m experimenting with an MA(1) model for time series analysis, but I’m facing an error and I would like to ask for some advice.

I am running the following code:

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample
import numpyro.distributions as dist
from jax import random
from numpyro.contrib.control_flow import scan
import jax.numpy as jnp


def model(y):

    n_series = y.shape[1]
    with numpyro.plate("plate", n_series):
        theta = numpyro.sample("theta", dist.Normal(0.0, 70.0))
    
    def transition_fn(carry, t_idx):
        epsilon_prev = carry
        mean = theta*epsilon_prev
        y_pred = numpyro.sample("y_pred", dist.Normal(mean, 1.0))
        carry_new = y_pred - mean
        return carry_new, y_pred
    
    time_indices = jnp.arange(0, y.shape[0])
    init_carry = jnp.zeros(n_series)  # JAX's jnp.empty is zero-filled anyway; zeros makes the initial residual explicit
    with numpyro.handlers.condition(data={"y_pred": y}):
        _, y_predictions = scan(
            transition_fn,
            init_carry,
            time_indices,
        )
       
    
rng_key = random.PRNGKey(12)
seeds = random.split(rng_key, 10)
n_series = 1
nuts_kernel = NUTS(model, init_strategy=init_to_sample)
train = random.normal(shape=(200,n_series), key=seeds[0])
mcmc = MCMC(nuts_kernel, num_samples=4000, num_warmup=2000, num_chains=1)
mcmc.run(seeds[1], y=train)
mcmc.print_summary()

This produces the following traceback:

Traceback (most recent call last):
  File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 37, in <module>
    mcmc.run(seeds[1], y=train)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 593, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 381, in _single_chain_mcmc
    init_state = self.sampler.init(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 706, in init
    init_params = self._init_state(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 652, in _init_state
    init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/util.py", line 676, in initialize_model
    substituted_model(*model_args, **model_kwargs)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 24, in model
    _, y_predictions = scan(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
    msg = apply_stack(initial_msg)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 53, in apply_stack
    default_process_message(msg)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 28, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 305, in scan_wrapper
    last_carry, (pytree_trace, ys) = lax.scan(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 278, in scan
    return tree_unflatten(out_tree, out)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/jax/_src/tree_util.py", line 71, in tree_unflatten
    return treedef.unflatten(leaves)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 153, in tree_unflatten
    return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 99, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/continuous.py", line 1701, in __init__
    super(Normal, self).__init__(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 177, in __init__
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Normal distribution got invalid loc parameter.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/contextlib.py", line 131, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 80, in validation_enabled
    yield
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/infer/util.py", line 676, in initialize_model
    substituted_model(*model_args, **model_kwargs)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/paul/.config/JetBrains/PyCharmCE2022.1/scratches/test2.py", line 24, in model
    _, y_predictions = scan(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
    msg = apply_stack(initial_msg)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 53, in apply_stack
    default_process_message(msg)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/primitives.py", line 28, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/contrib/control_flow/scan.py", line 305, in scan_wrapper
    last_carry, (pytree_trace, ys) = lax.scan(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 153, in tree_unflatten
    return cls(**dict(zip(sorted(cls.arg_constraints.keys()), params)))
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 99, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/continuous.py", line 1701, in __init__
    super(Normal, self).__init__(
  File "/home/paul/anaconda3/envs/bayes/lib/python3.8/site-packages/numpyro/distributions/distribution.py", line 177, in __init__
    raise ValueError(
ValueError: Normal distribution got invalid loc parameter.

The key line is "ValueError: Normal distribution got invalid loc parameter". From debugging, it appears that the loc parameter diverges to infinity, and this happens before the chain even starts warmup. I think the loc parameter in question must be the mean, but I don't see why it would diverge, since it's the product of two normally distributed variables, theta and y_pred - mean, both of which have finite scale parameters.

When looking at a single series, this problem can easily be fixed by changing the random seed or using a more informative prior, since theta = numpyro.sample("theta", dist.Normal(0.0, 70.0)) gives a pretty wide range. But the more series I include, the less reliable these workarounds become. For example, with n_series=50 and

with numpyro.plate("plate", n_series):
    theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))

I cannot find a seed which allows the MCMC sampling to be successful. So even though for a single series the problem can be sidestepped, I think for many series I will need to solve it.

Does anybody have any advice on this issue? Would be glad to hear it.

We provide various initialization strategies which might be helpful in your case. For example,

  • init_to_feasible
  • init_to_uniform(radius=0.1)
  • init_to_value(values=…)

Regarding the explosion, you can derive an analytical solution for the mean assuming y = 0 to see that the magnitude of the mean keeps increasing when |theta| > 1.
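Writing ε_t for the carry (the residual y_pred - mean) and μ_t for the loc passed to the Normal, a sketch of that derivation (the notation is mine; ε_0 is the initial carry):

```latex
\begin{aligned}
\mu_t &= \theta\,\varepsilon_{t-1}, \qquad \varepsilon_t = y_t - \mu_t = y_t - \theta\,\varepsilon_{t-1},\\
\varepsilon_t &= \sum_{k=0}^{t-1} (-\theta)^k\, y_{t-k} + (-\theta)^t\,\varepsilon_0,\\
y_t \equiv 0 \;\Rightarrow\; \varepsilon_t &= (-\theta)^t\,\varepsilon_0, \qquad |\mu_t| = |\theta|^t\,|\varepsilon_0|.
\end{aligned}
```

With real data the middle line is the relevant one: the observation k steps back enters the residual with weight (−θ)^k, so for |θ| > 1 the loc grows geometrically along the series. For draws with |θ| around 70 (plausible under the Normal(0, 70) prior), this should exceed the float32 range after roughly twenty steps.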


Thank you. This is very interesting. I'm learning a lot about how initialisation works.

Since I am sampling y_pred from a normal distribution about the predicted mean, I had assumed that y_pred - mean would also be normally distributed with finite variance. But if I understand correctly, because I am supplying observed data, y_pred is taken to be whatever I supply, so there is no guarantee that y_pred - mean will remain finite if abs(theta) > 1.
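To make this concrete, here is a small sketch that runs the same residual recursion outside NumPyro. Assumptions: plain NumPy stands in for the scan, float32 mimics JAX's default precision, random draws stand in for the observed data, and ma1_residuals is a hypothetical helper.

```python
import numpy as np


def ma1_residuals(y, theta, eps0=0.0):
    """Run eps_t = y_t - theta * eps_{t-1} over observed y, as the scan's carry does."""
    theta = np.float32(theta)
    eps = np.float32(eps0)
    out = np.empty(len(y), dtype=np.float32)
    # silence float32 overflow warnings so the inf values are simply returned
    with np.errstate(over="ignore", invalid="ignore"):
        for t, y_t in enumerate(y):
            mean = theta * eps  # plays the role of the loc passed to dist.Normal
            eps = y_t - mean    # residual carried to the next step
            out[t] = eps
    return out


rng = np.random.default_rng(0)
y = rng.standard_normal(200).astype(np.float32)  # stand-in for observed data

stable = ma1_residuals(y, theta=0.5)
unstable = ma1_residuals(y, theta=50.0)  # a plausible draw from Normal(0, 70)
print(np.isfinite(stable).all())    # True
print(np.isfinite(unstable).all())  # False
```

The observed values never adapt to the current theta, so nothing pulls the residual back towards the Normal's scale: for |theta| < 1 it stays bounded by roughly twice the largest observation, while for large |theta| it overflows to inf.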

It seems that any initialisation strategy that provides a sufficiently small value of theta will succeed, e.g. nuts_kernel = NUTS(model, init_strategy=init_to_value(values={"theta": 0.0})).

However, so far I have found that this model cannot fit any dataset generated with theta greater than 1. For example, we could create the following synthetic dataset:

noise = random.normal(shape=(2000, n_series), key=seeds[0])
noise = noise.at[0].set(0.0)
train = noise[1:] + 2*noise[:-1]
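As a quick sanity check on this synthetic series (a sketch in plain NumPy rather than JAX, with its own seed and a hypothetical helper lag1_autocorr), the sample lag-1 autocorrelation of an MA(1) process should be close to θ/(1 + θ²), which is 0.4 for θ = 2. Note that θ = 0.5 gives exactly the same value.

```python
import numpy as np

rng = np.random.default_rng(0)
noise = rng.standard_normal(2000)
noise[0] = 0.0                      # mirrors noise.at[0].set(0.0)
train = noise[1:] + 2 * noise[:-1]  # y_t = e_t + 2 * e_{t-1}, i.e. theta = 2


def lag1_autocorr(x):
    """Sample lag-1 autocorrelation of a 1-D series."""
    x = x - x.mean()
    return np.dot(x[1:], x[:-1]) / np.dot(x, x)


theta = 2.0
print(lag1_autocorr(train))    # sample estimate, close to 0.4
print(theta / (1 + theta**2))  # MA(1) theory: 0.4
print(0.5 / (1 + 0.5**2))      # theta = 0.5 gives the same: 0.4
```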

The appropriate value of theta in this case is 2, but the model cannot fit it. Even though the model is correctly specified for this data, it is very unstable and does not recover the correct answer. If I run the following code:

import numpyro
from numpyro.infer import MCMC, NUTS, init_to_sample, init_to_feasible, init_to_value
import numpyro.distributions as dist
from jax import random
from numpyro.contrib.control_flow import scan
import jax.numpy as jnp


def model(y):
    n_series = y.shape[1]
    with numpyro.plate("plate", n_series):
        theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))

    def transition_fn(carry, t_idx):
        epsilon_prev = carry
        mean = theta * epsilon_prev
        y_pred = numpyro.sample("y_pred", dist.Normal(mean, 1.0))
        carry_new = y_pred - mean
        return carry_new, y_pred

    time_indices = jnp.arange(0, y.shape[0])
    init_carry = jnp.zeros(n_series)  # JAX's jnp.empty is zero-filled anyway; zeros makes the initial residual explicit
    with numpyro.handlers.condition(data={"y_pred": y}):
        _, y_predictions = scan(
            transition_fn,
            init_carry,
            time_indices,
        )

rng_key = random.PRNGKey(12)
seeds = random.split(rng_key, 10)
n_series = 1
nuts_kernel = NUTS(model, init_strategy=init_to_value(values={"theta": 0.0}))

noise = random.normal(shape=(200, n_series), key=seeds[0])
noise = noise.at[0].set(0.0)
train = noise[1:] + 2*noise[:-1]

mcmc = MCMC(nuts_kernel, num_samples=40000, num_warmup=20000, num_chains=4)
mcmc.run(seeds[1], y=train)
mcmc.print_summary()

I get the following output:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
     theta      0.44      0.03      0.44      0.39      0.50  57943.36      1.00

Number of divergences: 0

There is no indication from n_eff or r_hat that a problem occurred, but the estimate of theta is wrong. Also, if I use nuts_kernel = NUTS(model, init_strategy=init_to_value(values={"theta": 2.0})) (i.e. I initialise at the correct answer), then I get the same "Normal distribution got invalid loc parameter" ValueError again.

Fortunately, this example is quite contrived and unlikely to occur with realistic data. In fact, if we include a scale parameter sigma for the y_pred distribution

y_pred = numpyro.sample("y_pred", dist.Normal(mean, sigma))

then we can show that sigma = 1, theta = 2 (the values used to generate the data) is equivalent to sigma = 2, theta = 0.5 (which is what the model returns once sigma is introduced):

             mean       std    median      5.0%     95.0%     n_eff     r_hat
  sigma      1.95      0.09      1.95      1.79      2.09   1019.56      1.00
  theta      0.46      0.06      0.46      0.37      0.56    786.29      1.00

Number of divergences: 0
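The equivalence can be checked from the autocovariances of an MA(1) process. This is a sketch in my own notation, with e_t for the innovations and γ_k for the lag-k autocovariance:

```latex
\begin{aligned}
&y_t = e_t + \theta\,e_{t-1}, \qquad e_t \overset{\text{iid}}{\sim} \mathcal{N}(0, \sigma^2)\\
&\gamma_0 = \sigma^2 (1 + \theta^2), \qquad \gamma_1 = \theta\,\sigma^2, \qquad \gamma_k = 0 \ (k \ge 2)\\
&(\theta, \sigma) \to (1/\theta,\ \theta\sigma): \quad \theta^2\sigma^2\bigl(1 + \theta^{-2}\bigr) = \sigma^2 (1 + \theta^2), \qquad \theta^{-1}\,\theta^2\sigma^2 = \theta\,\sigma^2\\
&\theta = 2,\ \sigma = 1 \ \text{or}\ \theta = 0.5,\ \sigma = 2: \quad \gamma_0 = 5, \qquad \gamma_1 = 2
\end{aligned}
```

Since a zero-mean Gaussian process is determined by its autocovariances, both parameter pairs describe the same distribution for y. Of the two, the invertible choice |θ| < 1 is the one the residual recursion can evaluate stably, which is consistent with the sigma ≈ 1.95, theta ≈ 0.46 estimates above.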

But this does make me wonder whether there are other, more realistic situations where we can write correct but very unstable models that return seemingly accurate results (large n_eff, r_hat = 1, small std, 0 divergences) which are actually wrong.

I guess this shows we need to be careful and understand our models and data well.

Anyway, thanks for the initialisation suggestions!

Hmm, this seems like a bug. Could you open a GitHub issue for this? I will take a look later this week.

Thank you. I would be glad if you could look into it further. Here is the issue: https://github.com/pyro-ppl/numpyro/issues/1453