We have made an AR model in NumPyro and used it to impute missing values in a time series data set. The implementation below seems to work; however, we found that the order in which the variables “noises” and “z_prev” are defined made a huge difference, which we did not expect. If “noises” is defined before “z_prev”, our model is unable to learn anything.
The details of our dataset etc. are not that important; however, we were confused about why the ordering mattered, as “noises” and “z_prev” do not seem to depend on each other.
```python
import jax.numpy as jnp
from jax import lax
import numpyro
import numpyro.distributions as dist

# `f`, the transition function passed to lax.scan, is defined elsewhere (not shown).


def model(T, T_forecast, obs1=None, ix_mis1=None, ix_obs1=None,
          obs2=None, ix_mis2=None, ix_obs2=None):
    # AR parameters
    beta = numpyro.sample("beta", dist.Normal(jnp.zeros(2), jnp.ones(2)))
    # Noises in
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1.))
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(dimension=2, concentration=10.))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)

    # Defining "z_prev" before "noises" works; the reverse order does not learn
    z_prev = numpyro.sample(name="z_prev", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))
    noises = numpyro.sample("noises",
                            fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower),
                            sample_shape=(T + T_forecast - 1,))

    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]  # note: overwritten by the lax.scan output below
    carry_updated, z_collection = lax.scan(f=f, init=carry, xs=noises,
                                           length=T + T_forecast - 1)

    # Sample the observed y (y_obs) and missing y (y_mis)
    y_obs1 = numpyro.sample('y_obs1', dist.Normal(z_collection[ix_obs1, 0], sigma), obs=obs1)
    y_obs2 = numpyro.sample('y_obs2', dist.Normal(z_collection[ix_obs2, 1], sigma), obs=obs2)
    y_mis1 = numpyro.sample('y_mis1', dist.Normal(z_collection[ix_mis1, 0], sigma))
    y_mis2 = numpyro.sample('y_mis2', dist.Normal(z_collection[ix_mis2, 1], sigma))
    return y_obs1, y_obs2
```
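The transition function `f` is not shown in the post, so the following is only a hedged sketch with a hypothetical elementwise AR(1) update, `z_t = beta * z_{t-1} + noise_t`, using the same `(beta, z_prev, tau)` carry (here `tau` is just passed through unchanged). It also illustrates that `lax.scan` returns only the per-step outputs: with `T + T_forecast - 1` noise terms it yields `T + T_forecast - 1` states, and the initial `z_prev` is not among them. The line `z_collection = [z_prev]` is overwritten by the scan result, so if the indices such as `ix_obs1` assume a series of length `T + T_forecast` that starts at `z_prev`, the initial state has to be prepended explicitly, as the hypothetical `propagate` helper does.

```python
import jax.numpy as jnp
from jax import lax


def f(carry, noise_t):
    # Hypothetical AR(1) step: z_t = beta * z_{t-1} + noise_t (elementwise).
    beta, z_prev, tau = carry
    z_t = beta * z_prev + noise_t
    # Carry the new state forward and emit z_t as this step's output.
    return (beta, z_t, tau), z_t


def propagate(beta, z_prev, tau, noises):
    # Hypothetical helper: run the recursion over all noise terms.
    _, zs = lax.scan(f, (beta, z_prev, tau), noises)
    # lax.scan returns one output per step only; prepend the initial state.
    return jnp.concatenate([z_prev[None, :], zs], axis=0)


z = propagate(jnp.array([0.5, 0.5]), jnp.array([1.0, 2.0]),
              jnp.ones(2), jnp.zeros((3, 2)))
print(z.shape)  # (4, 2): the initial state plus three propagated steps
```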