Improving performance of imputation by pruning

If I'm reading the post "Hierarchical Bayesian Model with data of varying length" (#6 by fehiepsi) correctly, obs_mask doesn't work correctly with MultivariateNormal, so I've ended up with code like this instead:

```python
import numpyro as ny
import numpyro.distributions as dist
import numpy as np
from jax import random
from jax import numpy as jnp

def model(shape, data):
    # NaN positions are fixed for a run; data is a concrete NumPy array here
    nan_idx = np.nonzero(np.isnan(data))
    width = shape[1]
    y = ny.sample("y", dist.Normal(loc=jnp.zeros(width), scale=jnp.ones(width)))
    # Plate over rows (for MultivariateNormal this should be dim=-1, see the reply)
    with ny.plate("data", shape[0], dim=-2):
        # Full-size latent, one value per entry; mask(False) drops its own log-density
        imputed_x = jnp.squeeze(ny.sample("x_imp", dist.MultivariateNormal(loc=y, covariance_matrix=jnp.identity(width)).mask(False)))
        # Fill the NaNs with imputed values
        # (originally ops.index_update(data, nan_idx, imputed_x[nan_idx]); .at[].set() in current JAX)
        observed_and_imputed_x = jnp.asarray(data).at[nan_idx].set(imputed_x[nan_idx])
        x = ny.sample("x", dist.MultivariateNormal(loc=y + 2, covariance_matrix=jnp.identity(width)), obs=observed_and_imputed_x)

mcmc = ny.infer.MCMC(ny.infer.NUTS(model=model), num_warmup=500, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(1), data=np.random.random((100, 8)), shape=(100, 8))
mcmc.print_summary()
```

This code seems to work, but it's also dramatically slower than the equivalent model without imputation. Based on the discussion here, I assume this is because it ends up creating a giant graph with one node for each data point. That doesn't seem strictly necessary: we really only want one additional node for each entry actually in need of imputation (i.e. one per NaN) rather than for every entry (i.e. batch size × event size of them). And for each MCMC run, we know ahead of time how many imputations we need and where. I'd like to do something like:

```python
imputed_x = ny.sample("x_imp", lambda key: dist.MultivariateNormal(loc=y, covariance_matrix=jnp.identity(width)).mask(False)(key)[nan_idx])
```

so only the actual imputations get registered. But this code doesn’t quite work. Is this approach generally sensible? Any thoughts on the best way to implement it? A Transform?

Hi @cole_haus, given N data points X where each point X[i] has m features, there are two kinds of missing data:

  • An entire data point X[i] is missing: you can use obs_mask or the usual imputation strategy for this.
  • Some features of a data point X[i] are missing. This is trickier: we need some sort of Gaussian conditional here. The conditional logic is different for each data point, so we might need to use a for loop, which is expensive.
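For the second case, the standard Gaussian conditioning formulas give the distribution of a row's missing features given its observed ones: mean μ_m + Σ_mo Σ_oo⁻¹ (x_o − μ_o) and covariance Σ_mm − Σ_mo Σ_oo⁻¹ Σ_om. A minimal sketch for a single row follows; the helper name `gaussian_conditional` is hypothetical, and it assumes the missing pattern is a concrete NumPy boolean array with at least one observed entry. Because each row's pattern yields differently shaped blocks, applying it across rows would need the per-row Python loop mentioned above rather than `vmap`.

```python
import numpy as np
import jax.numpy as jnp


def gaussian_conditional(mu, cov, x, missing):
    """Mean and covariance of x[missing] given x[~missing] when x ~ N(mu, cov).

    Hypothetical helper. missing: concrete boolean NumPy array, True where x
    is NaN. Assumes at least one entry of x is observed.
    """
    m = np.nonzero(missing)[0]   # indices of missing features
    o = np.nonzero(~missing)[0]  # indices of observed features
    cov_oo = cov[np.ix_(o, o)]
    cov_mo = cov[np.ix_(m, o)]
    cov_mm = cov[np.ix_(m, m)]
    # mu_m + S_mo S_oo^{-1} (x_o - mu_o)
    mean = mu[m] + cov_mo @ jnp.linalg.solve(cov_oo, x[o] - mu[o])
    # S_mm - S_mo S_oo^{-1} S_om
    cond_cov = cov_mm - cov_mo @ jnp.linalg.solve(cov_oo, cov_mo.T)
    return mean, cond_cov


mu = jnp.zeros(2)
cov = jnp.array([[1.0, 0.5], [0.5, 1.0]])
x = jnp.array([1.0, jnp.nan])
mean, cond_cov = gaussian_conditional(mu, cov, x, np.array([False, True]))
print(mean, cond_cov)
```

For a covariance matrix, a Cholesky-based solve would be the more usual choice than `jnp.linalg.solve`; the plain solve just keeps the sketch short.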

Regarding the performance, you might want to add some print statements to the model to check that you get the expected shapes. One thing: if you use MVN, the plate dimension is -1, not -2. We consider each point X[i] to be an element of that plate. Please see the "Tensor shapes in Pyro" tutorial (Pyro Tutorials 1.7.0 documentation) for more information on shape semantics in Pyro.

Ahh, yeah, I was using a plate dimension of -1 and accidentally left it at -2 after fiddling around. Thanks for catching that.

Unfortunately, in my case, I have the more complicated scenario of some features of the data point missing. I’ll look into the Gaussian conditional approach.

Re: performance, I was actually mistaken about the problem. Looking again, it seems like the ops.index_update is the primary culprit. I suppose there’s no real way around that given the design of JAX.
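If the full-size `imputed_x` from the original model is kept, the gather-then-scatter `ops.index_update(data, nan_idx, imputed_x[nan_idx])` can also be written as a single elementwise select with `jnp.where`. This is offered only as an alternative formulation: whether it is measurably faster than the scatter depends on backend and problem size and would need profiling. The sketch below just checks that both forms produce the same array on a tiny made-up input.

```python
import numpy as np
import jax.numpy as jnp

data = np.array([[1.0, np.nan, 3.0],
                 [np.nan, 5.0, 6.0]])
nan_mask = np.isnan(data)
nan_idx = np.nonzero(nan_mask)

# Stand-in for the full-size imputed_x from the model (one value per entry).
imputed_x = jnp.arange(6.0).reshape(2, 3) + 10.0

# Original form: gather the NaN positions, then scatter them into data
# (ops.index_update(data, nan_idx, imputed_x[nan_idx]) in older JAX).
via_scatter = jnp.asarray(data).at[nan_idx].set(imputed_x[nan_idx])

# Alternative: one elementwise select, taking imputed_x wherever data is NaN.
via_where = jnp.where(nan_mask, imputed_x, data)

print(via_where)
```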

Thanks again!