I am trying to fit a multivariate time series model in NumPyro. The model is a state-space model with AR(1) dynamics. The data consist of two time series with missing values. I used the non-centered parametrization and `jax.lax.scan`, as in "Using lax.scan for time series in NumPyro". I am doing inference with NUTS. The generative process and the code of the model are the following:
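Read off from the model code, the generative process is as follows. Here k ∈ {1, 2} indexes the two series, `z_prev` in the code is the first state z_1, N(mean, variance) notation is used, and the per-series observation scale σ_k assumes that `sigma[k]` is the scale intended for series k:

```latex
\begin{aligned}
\beta_k &\sim \mathcal{N}(0, 1), \qquad \sigma_k \sim \mathrm{HalfCauchy}(1), \qquad \tau_k \sim \mathrm{HalfCauchy}(1), \\
z_{1,k} &\sim \mathcal{N}(0, 1), \qquad L_\Omega \sim \mathrm{LKJCholesky}(2,\ \eta = 1), \\
\Sigma &= \operatorname{diag}(\sqrt{\tau})\, L_\Omega L_\Omega^\top \operatorname{diag}(\sqrt{\tau}), \\
\varepsilon_t &\sim \mathcal{N}(0, \Sigma), \qquad t = 2, \dots, T + T_{\text{forecast}}, \\
z_t &= \beta \odot z_{t-1} + \varepsilon_t, \\
y_{t,k} &\sim \mathcal{N}(z_{t,k}, \sigma_k^2), \qquad t = 1, \dots, T + T_{\text{forecast}}.
\end{aligned}
```

For t ≤ T each y_{t,k} is either observed (`y1_obs`, `y2_obs`) or missing (`y1_mis`, `y2_mis`); for t > T it is a forecast (`y1_pred`, `y2_pred`).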
```python
import jax.numpy as jnp
from jax import lax
import numpyro
import numpyro.distributions as dist


# Step function for jax.lax.scan: z_t = beta * z_{t-1} + noise_t (elementwise AR(1))
def f(carry, noise_t):
    beta, z_prev = carry
    z_t = beta * z_prev + noise_t
    return (beta, z_t), z_t


def model(T, T_forecast, obs1=None, ix_mis1=None, ix_obs1=None,
          obs2=None, ix_mis2=None, ix_obs2=None):
    # Priors for the betas, sigma and the first state
    beta = numpyro.sample("beta", dist.Normal(jnp.zeros(2), jnp.ones(2)))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(jnp.ones(2)))
    z_prev = numpyro.sample("z_prev", dist.Normal(jnp.zeros(2), jnp.ones(2)))

    # LKJ prior (scaled by tau) for the correlated, non-centered noises
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(2)))
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(dimension=2, concentration=1.0))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)
    noises = numpyro.sample(
        "noises",
        dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
        sample_shape=(T + T_forecast - 1,),
    )

    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev)
    _, zs = lax.scan(f, carry, noises, length=T + T_forecast - 1)
    # Shape (T + T_forecast, 2); row 0 is the first state z_prev
    z_collection = jnp.concatenate((z_prev[None, :], zs), axis=0)

    # Observed (y_obs), missing (y_mis) and forecast (y_pred) values.
    # sigma has shape (2,), so each series uses its own scale sigma[0] / sigma[1]
    # (passing the whole vector would not broadcast against the indexed states).
    y1_obs = numpyro.sample("y1_obs", dist.Normal(z_collection[ix_obs1, 0], sigma[0]), obs=obs1)
    y1_mis = numpyro.sample("y1_mis", dist.Normal(z_collection[ix_mis1, 0], sigma[0]))
    y1_pred = numpyro.sample("y1_pred", dist.Normal(z_collection[T:, 0], sigma[0]))
    y2_obs = numpyro.sample("y2_obs", dist.Normal(z_collection[ix_obs2, 1], sigma[1]), obs=obs2)
    y2_mis = numpyro.sample("y2_mis", dist.Normal(z_collection[ix_mis2, 1], sigma[1]))
    y2_pred = numpyro.sample("y2_pred", dist.Normal(z_collection[T:, 1], sigma[1]))
```
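The `lax.scan` recursion used in the model can be sanity-checked on its own by comparing it with a plain Python loop. This is a minimal sketch with made-up values for `beta`, the first state and the number of steps (9, standing in for `T + T_forecast - 1`); `ar1_path_scan` and `ar1_path_loop` are hypothetical helper names, not part of the model.

```python
import jax.numpy as jnp
from jax import lax, random


def f(carry, noise_t):
    # Same AR(1) step as in the model: z_t = beta * z_{t-1} + noise_t
    beta, z_prev = carry
    z_t = beta * z_prev + noise_t
    return (beta, z_t), z_t


def ar1_path_scan(beta, z1, noises):
    # Stack the first state on top of the scanned states, as the model does
    _, zs = lax.scan(f, (beta, z1), noises)
    return jnp.concatenate([z1[None, :], zs], axis=0)


def ar1_path_loop(beta, z1, noises):
    # Reference implementation with an explicit Python loop
    path = [z1]
    for eps in noises:
        path.append(beta * path[-1] + eps)
    return jnp.stack(path)


key = random.PRNGKey(0)
beta = jnp.array([0.8, -0.5])
z1 = jnp.array([1.0, 0.0])
noises = random.normal(key, (9, 2))  # 9 hypothetical noise steps

scan_path = ar1_path_scan(beta, z1, noises)
loop_path = ar1_path_loop(beta, z1, noises)
print(scan_path.shape)
```

Both paths have one row per time step, so with 9 noise steps they contain 10 states, matching the `(T + T_forecast, 2)` shape of `z_collection`.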
I have two questions about the model:
- Why does the model “learn” a distribution for each noise term? I would expect the posteriors of the noises to be just standard Gaussians. Almost all of them are, but some clearly are not. If the noise is “exogenous”, why can’t I just pass in NumPy samples from a standard Gaussian?
- When I run the model without the predictions (`y1_pred` and `y2_pred`), it runs relatively smoothly: the issue from the first question remains, but the results are sensible. Whenever I include the forecasting part, however, the model collapses. That is, each posterior becomes a single value sampled over and over, with n_eff=0.5 for all of them.
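The effect behind the first question can be seen in a toy linear-Gaussian case. A sample site without `obs` is a latent variable, so its posterior is the prior updated by whatever the data say about it. This sketch assumes a single scalar noise eps ~ N(0, 1) seen directly through y ~ N(eps, sigma_obs²). It is a stand-in for one noise term of the AR(1) model, not the model itself, and `noise_posterior` and `noise_posterior_grid` are hypothetical helpers.

```python
import numpy as np


def noise_posterior(y, sigma_obs):
    """Closed-form posterior of eps for eps ~ N(0, 1) and y | eps ~ N(eps, sigma_obs**2).

    Gaussian conjugacy: precisions add, and the posterior mean is the
    precision-weighted share of the observation.
    """
    prior_prec = 1.0
    lik_prec = 1.0 / sigma_obs**2
    post_var = 1.0 / (prior_prec + lik_prec)
    post_mean = post_var * lik_prec * y
    return post_mean, post_var


def noise_posterior_grid(y, sigma_obs, n=200001):
    """Same posterior computed numerically on a grid, as a cross-check."""
    grid = np.linspace(-10.0, 10.0, n)
    log_post = -0.5 * grid**2 - 0.5 * (y - grid) ** 2 / sigma_obs**2
    w = np.exp(log_post - log_post.max())
    w /= w.sum()
    mean = np.sum(w * grid)
    var = np.sum(w * (grid - mean) ** 2)
    return mean, var


# Weakly informative observation: the posterior stays close to the N(0, 1) prior.
weak = noise_posterior(y=2.0, sigma_obs=10.0)
# Precise observation: the posterior is pulled far away from the prior.
strong = noise_posterior(y=2.0, sigma_obs=0.1)
print(weak, strong)
```

With `sigma_obs=10.0` the posterior is roughly N(0.02, 0.99), almost the prior, while with `sigma_obs=0.1` it is roughly N(1.98, 0.0099). In a state-space model, the noise terms whose effect on the states is pinned down by the data behave like the second case.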