If I read Hierarchical Bayesian Model with data of varying length - #6 by fehiepsi correctly, obs_mask doesn't work correctly with MultivariateNormal. Thus I instead end up with code like:
import numpyro as ny
import numpyro.distributions as dist
import numpy as np
from jax import ops, random
from jax import numpy as jnp

def model(shape, data):
    # Locations of the missing entries are known before inference starts.
    nan_idx = np.nonzero(np.isnan(data))
    width = shape[1]
    y = ny.sample("y", dist.Normal(loc=np.zeros(width), scale=jnp.ones(width)))
    with ny.plate("data", shape[0], dim=-2):
        # Imputation site: mask(False) keeps its log-density out of the joint.
        imputed_x = jnp.squeeze(ny.sample("x_imp", dist.MultivariateNormal(loc=y, covariance_matrix=jnp.identity(width)).mask(False)))
        # Fill the NaNs with imputed values before conditioning on the data.
        observed_and_imputed_x = ops.index_update(data, nan_idx, imputed_x[nan_idx])
        x = ny.sample("x", dist.MultivariateNormal(loc=y + 2, covariance_matrix=jnp.identity(width)), obs=observed_and_imputed_x)

mcmc = ny.infer.MCMC(ny.infer.NUTS(model=model), num_warmup=500, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(1), data=np.random.random((100, 8)), shape=(100, 8))
mcmc.print_summary()
This code seems to work, but it is also dramatically slower than the equivalent model without imputation. Based on the discussion here, I assume this is because it ends up creating a giant graph with one node for each data point. That doesn't seem strictly necessary, since we really only want one additional node for each data point that actually needs imputation (i.e. one per NaN) rather than one for every data point simpliciter (i.e. one per batch size × event size). And for each MCMC run, we know ahead of time how many imputations we need and where. I'd like to do something like:
imputed_x = ny.sample("x_imp", lambda key: dist.MultivariateNormal(loc=y, covariance_matrix=jnp.identity(width)).mask(False)(key)[nan_idx])
so only the actual imputations get registered. But this code doesn't quite work. Is this approach generally sensible? Any thoughts on the best way to implement it? A Transform?
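For concreteness, here is roughly the kind of model body I have in mind, using the same imports as above. This is only a rough sketch of what I'm imagining, not code I have verified: the site name "x_imp_flat", the Normal(0, 10) dummy prior, and scattering with .at[...].set(...) are just my guesses at one way to register only the per-NaN latents.

def model(shape, data):
    # Locations and count of the missing entries, known before inference starts.
    nan_idx = np.nonzero(np.isnan(data))
    n_nan = nan_idx[0].shape[0]
    width = shape[1]
    y = ny.sample("y", dist.Normal(loc=np.zeros(width), scale=jnp.ones(width)))
    # One scalar latent per NaN; mask(False) keeps the dummy prior out of the joint,
    # so these values would be informed only by the MultivariateNormal likelihood below.
    imputed = ny.sample("x_imp_flat", dist.Normal(0.0, 10.0).expand([n_nan]).mask(False))
    # Scatter the per-NaN latents into the observed matrix.
    observed_and_imputed_x = jnp.asarray(data).at[nan_idx].set(imputed)
    with ny.plate("data", shape[0], dim=-2):
        ny.sample("x", dist.MultivariateNormal(loc=y + 2, covariance_matrix=jnp.identity(width)), obs=observed_and_imputed_x)

The hope is that HMC then only has to carry n_nan extra latent dimensions instead of shape[0] × shape[1], which is where I expect the speed to come back, but I don't know whether registering the imputation site outside the plate like this is legitimate.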