Hi
We have built an autoregressive (AR) model in NumPyro and used it to impute missing values in a time series data set. The implementation below seems to work; however, we found that the order in which the variables `noises` and `z_prev` are defined makes a huge difference, which we did not expect. If `noises` is defined before `z_prev`, our model is unable to learn anything.
The details of our dataset are not that important. What confuses us is why the ordering matters at all, since `noises` and `z_prev` do not appear to depend on each other.
```python
import jax.numpy as jnp
from jax import lax
import numpyro
import numpyro.distributions as dist


def model(T, T_forecast, obs1=None, ix_mis1=None, ix_obs1=None,
          obs2=None, ix_mis2=None, ix_obs2=None):
    # AR parameters
    beta = numpyro.sample("beta", dist.Normal(jnp.zeros(2), jnp.ones(2)))
    # Noise scales, observation noise and noise correlation
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1.))
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(dimension=2, concentration=10.))
    # Cholesky factor of the noise covariance: diag(sqrt(tau)) @ L_Omega
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)
    # Initial state; moving this line below `noises` is what stops the model learning
    z_prev = numpyro.sample(name="z_prev", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))
    # One 2-d innovation per time step after the first
    noises = numpyro.sample("noises",
                            fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
                            sample_shape=(T + T_forecast - 1,))

    # Propagate the dynamics forward using jax.lax.scan
    # (`f` is the transition function, defined elsewhere and not shown here)
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]  # note: overwritten by the scan output on the next line
    carry_updated, z_collection = lax.scan(f=f, init=carry, xs=noises, length=T + T_forecast - 1)

    # Sample the observed y (y_obs) and missing y (y_mis)
    y_obs1 = numpyro.sample("y_obs1", dist.Normal(z_collection[ix_obs1, 0], sigma), obs=obs1)
    y_obs2 = numpyro.sample("y_obs2", dist.Normal(z_collection[ix_obs2, 1], sigma), obs=obs2)
    y_mis1 = numpyro.sample("y_mis1", dist.Normal(z_collection[ix_mis1, 0], sigma))
    y_mis2 = numpyro.sample("y_mis2", dist.Normal(z_collection[ix_mis2, 1], sigma))
    return y_obs1, y_obs2
```
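The transition function `f` passed to `lax.scan` is not shown in the post. The sketch below is a hypothetical stand-in, assuming an element-wise AR(1) update `z_t = beta * z_{t-1} + noise_t` that matches the carry layout `(beta, z_prev, tau)`, with `tau` simply passed through. It also shows, as an assumption about intent, how `z_prev` could be prepended to the scan output, which the unused `z_collection = [z_prev]` line may have been aiming for.

```python
import jax.numpy as jnp
from jax import lax


def f(carry, noise):
    # Hypothetical AR(1) transition; the post's actual `f` is not shown
    beta, z_prev, tau = carry
    z_t = beta * z_prev + noise  # element-wise AR(1) for each of the 2 series
    return (beta, z_t, tau), z_t  # new carry, per-step output


beta = jnp.array([0.5, 0.5])
z0 = jnp.array([1.0, 2.0])
tau = jnp.ones(2)
noises = jnp.zeros((3, 2))  # zero noise makes the trajectory easy to check

_, zs = lax.scan(f, (beta, z0, tau), noises)  # shape (3, 2): z_1 .. z_3
# Optionally include the initial state so the trajectory covers all time steps
z_full = jnp.concatenate([z0[None, :], zs], axis=0)  # shape (4, 2)
print(z_full)
```

With this layout the scan output alone has `T + T_forecast - 1` rows; prepending `z_prev` gives `T + T_forecast` rows, so the index arrays `ix_obs*` / `ix_mis*` would need to match whichever convention is intended.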