Ordering of variables affects learning in AR model

Hi
We have built an AR model in NumPyro and used it to impute missing values in a time-series dataset. The implementation below seems to work; however, we found that the order in which the variables “noises” and “z_prev” are defined makes a huge difference, which we did not expect. If “noises” is defined before “z_prev”, our model is unable to learn anything.
The details of our dataset etc. are not that important, but we were confused about why the ordering mattered, since “noises” and “z_prev” do not seem to depend on each other.

```python
import jax.numpy as jnp
from jax import lax
import numpyro
import numpyro.distributions as dist


def model(T, T_forecast, obs1=None, ix_mis1=None, ix_obs1=None, obs2=None, ix_mis2=None, ix_obs2=None):
    # AR parameters
    beta = numpyro.sample("beta", dist.Normal(jnp.zeros(2), jnp.ones(2)))

    # Noise scales and correlation structure
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1.))
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(dimension=2, concentration=10.))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega)  # lower Cholesky factor of the covariance

    # Initial state and innovation noises -- swapping the order of these two
    # sample statements is what changes whether the model learns
    z_prev = numpyro.sample(name="z_prev", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))
    noises = numpyro.sample("noises", fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower), sample_shape=(T+T_forecast-1,))

    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    carry_updated, zs_exp = lax.scan(f=f, init=carry, xs=noises, length=T+T_forecast-1)
    z_collection = jnp.concatenate((z_prev[None, :], zs_exp), axis=0)  # prepend the initial state

    # Sample the observed y (y_obs) and missing y (y_mis)
    y_obs1 = numpyro.sample('y_obs1', dist.Normal(z_collection[ix_obs1, 0], sigma), obs=obs1)
    y_obs2 = numpyro.sample('y_obs2', dist.Normal(z_collection[ix_obs2, 1], sigma), obs=obs2)
    y_mis1 = numpyro.sample('y_mis1', dist.Normal(z_collection[ix_mis1, 0], sigma))
    y_mis2 = numpyro.sample('y_mis2', dist.Normal(z_collection[ix_mis2, 1], sigma))

    return y_obs1, y_obs2
```
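The transition function `f` is not shown in the snippet above. A minimal sketch of an AR(1)-style transition consistent with the `carry` structure, assuming simple elementwise dynamics (the exact form used in the original code may differ):

```python
# Hypothetical transition function for lax.scan; the actual `f` is not shown
# in this post. This assumes simple elementwise AR(1) dynamics.
def f(carry, noise_t):
    beta, z_prev, tau = carry
    z_t = beta * z_prev + noise_t  # next latent state (assumed dynamics)
    return (beta, z_t, tau), z_t   # updated carry, plus the value scan collects
```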

Hi @nikolasthuesen, this might be a serious issue. Could you confirm that the model’s failure to learn anything also happens with different random seeds?

I tried changing the random seed, but the model still wouldn’t learn if the order was “incorrect”. If you want to look into the full thing, I’ve made a Colab document:


Thanks, @nikolasthuesen! I am worried that there might be an issue with flattening/unflattening the dict of latent variables. I will look into this and get back to you.


@nikolasthuesen I ran your code locally but found that MCMC is not mixing with either order of z_prev and noises (r_hat is pretty large). I would recommend moving those y_mis* sites to the prediction (not inference) phase. To my knowledge, given a model p(y | x, theta), we use Bayesian imputation to impute missing values of x: we define a prior for x_mis and run MCMC to learn p(x_mis, theta | x_nomis, y). If the observations themselves are missing, as in your case, then after getting p(theta | x, y) with MCMC, you can generate the prediction p(y_mis | theta, x). FYI, if I remove y_mis1 and y_mis2, I see that MCMC converges.
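As a rough sketch of that two-phase workflow (assuming, hypothetically, that `model_no_mis` is your model with the y_mis* sites removed, while `model` is the full version that still contains them):

```python
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

# Inference phase: learn p(theta | x, y) using only the observed sites.
mcmc = MCMC(NUTS(model_no_mis), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), T, T_forecast, obs1=obs1, ix_obs1=ix_obs1,
         obs2=obs2, ix_obs2=ix_obs2)

# Prediction phase: draw p(y_mis | theta, x) from the posterior samples.
predictive = Predictive(model, posterior_samples=mcmc.get_samples())
preds = predictive(random.PRNGKey(1), T, T_forecast,
                   ix_mis1=ix_mis1, ix_obs1=ix_obs1,
                   ix_mis2=ix_mis2, ix_obs2=ix_obs2)
y_mis1_draws = preds["y_mis1"]  # posterior predictive draws for the missing y
```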

This also applies to forecasting. We should not do forecasting and infer latent variables at the same time: there is no objective for those forecasting results from which MCMC could learn anything. Basically, if you put a prior on x without any observation, MCMC will just return the same prior to you:

```python
def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
```

In your model, y_mis1 is simply

```python
y_mis1 = z_collection[ix_mis1, 0] + sigma * numpyro.sample('y_mis1_noise', dist.Normal(0, 1))
```

and MCMC will return the same prior Normal(0, 1) for the y_mis1_noise site. Without a reparameterization like that, it is pretty tricky for MCMC to do both jobs at once:

  • learn something that depends on other latent variables
  • infer variables that do not need to be inferred

The solution to the first issue is to reparameterize as above (a fuller sketch is below). For the second issue, simply move such inference to the prediction step.
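A sketch of what the reparameterized sites might look like, assuming ix_mis1 and ix_mis2 are index arrays; numpyro.deterministic is used here just to record the imputed values so they show up directly in mcmc.get_samples():

```python
# Non-centered parameterization of the missing-value sites (sketch).
# The auxiliary noise sites keep a fixed Normal(0, 1) posterior, so MCMC
# only has to mix over the latent variables that actually matter.
eps1 = numpyro.sample("y_mis1_noise", dist.Normal(0., 1.), sample_shape=(len(ix_mis1),))
eps2 = numpyro.sample("y_mis2_noise", dist.Normal(0., 1.), sample_shape=(len(ix_mis2),))
y_mis1 = numpyro.deterministic("y_mis1", z_collection[ix_mis1, 0] + sigma * eps1)
y_mis2 = numpyro.deterministic("y_mis2", z_collection[ix_mis2, 1] + sigma * eps2)
```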

Hi @fehiepsi, I see what you mean and agree that, computationally, MCMC could have an “easier life” at inference time with the suggestions you gave. On the other hand, I am not entirely convinced from a theoretical point of view. To me, the “Bayesian way” of doing data imputation is to treat missing observations as latent variables in the PGM. I tried implementing the above model in Stan (where missing observations are treated as latent variables), and the model seems able to successfully recover the posterior and impute missing values at inference time… do you have any idea what could cause this behavior in Pyro?

If it is of any help, here you can find the Stan implementation: Google Colab

Also, regarding this:

What is odd is that if you use the following ordering of the variables:

```python
def model(T, T_forecast, obs1=None, ix_mis1=None, ix_obs1=None, obs2=None, ix_mis2=None, ix_obs2=None):

    # Define priors over beta, tau, sigma, z_1 (keep the shapes in mind)
    beta = numpyro.sample(name="beta", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))
    tau = numpyro.sample(name="tau", fn=dist.HalfCauchy(scale=jnp.ones(2)))
    sigma = numpyro.sample(name="sigma", fn=dist.HalfCauchy(scale=.1))
    z_prev = numpyro.sample(name="z_1", fn=dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)))

    # Define LKJ prior
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, 10.))
    Sigma_lower = jnp.matmul(jnp.diag(jnp.sqrt(tau)), L_Omega) # lower cholesky factor of the covariance matrix
    noises = numpyro.sample("noises", fn=dist.MultivariateNormal(loc=jnp.zeros(2), scale_tril=Sigma_lower), sample_shape=(T+T_forecast-1,))
    
    # Propagate the dynamics forward using jax.lax.scan
    carry = (beta, z_prev, tau)
    z_collection = [z_prev]
    carry, zs_exp = lax.scan(f, carry, noises, T+T_forecast-1)
    z_collection = jnp.concatenate((jnp.array(z_collection), zs_exp), axis=0)

    # Sample the observed y (y_obs) and missing y (y_mis)
    numpyro.sample(name="y_mis1", fn=dist.Normal(loc=z_collection[ix_mis1, 0], scale=sigma), obs=None)
    numpyro.sample(name="y_obs1", fn=dist.Normal(loc=z_collection[ix_obs1, 0], scale=sigma), obs=obs1)
    numpyro.sample(name="y_mis2", fn=dist.Normal(loc=z_collection[ix_mis2, 1], scale=sigma), obs=None)
    numpyro.sample(name="y_obs2", fn=dist.Normal(loc=z_collection[ix_obs2, 1], scale=sigma), obs=obs2)

the model is able to successfully impute missing values (with very nice r_hat values).

Sorry for the late response! (I had written a reply but needed to think more about the issue, and then I forgot :frowning:)

I guess Stan treats variables that no other variables depend on (like your forecasting values) as “predictive” variables: after it runs MCMC over the remaining variables, it draws samples of those forecasting values in a separate predictive pass. Maybe I am wrong, but I can’t think of another reason.

About swapping the variables, I still can’t reproduce your observations. All of my runs give bad mixing results, so I’m not sure whether there is an actual issue in NumPyro. Anyway, thank you for raising this interesting and probably serious issue! I wish we had a more Pyro-idiomatic model (i.e. one where forecasting is treated differently, e.g. behind an `if forecast` flag or something like that) so the issue could be tracked down more easily.