This may be a naive question, as I'm still getting to grips with (Num)Pyro, but how does one handle missing data in a mixture model? From the tutorials, I assumed that because the model already defines a density for the data, missing values would not need any explicit handling. However, when I run the model below on a dataset with some missing values I get a runtime error, whereas on the same data with the missing values removed the model samples without problems. And since any imputation is conditional on the component assignment, I am not sure how to impute values elegantly.
If someone can explain how to handle this it would be much appreciated! Thank you.
Error:
RuntimeError: Cannot find valid initial parameters. Please check your model again.
Model:
@config_enumerate
def gaussianMixtureModel(data, K):
    N, P = data.shape
    # Global variables.
    weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    with numpyro.plate("components", K, dim=-2):
        with numpyro.plate("measurements", P, dim=-1):
            scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
            locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))
    with numpyro.plate("data", N, dim=-2) as n:
        # Local variables.
        assignment = numpyro.sample("assignment", dist.Categorical(weights))
        with numpyro.plate("measurements", P, dim=-1) as p:
            if data is None:
                data = numpyro.sample("obs", dist.Normal(Vindex(locs)[..., assignment, p], Vindex(scale)[..., assignment, p])).mask(False)
            numpyro.sample("obs", dist.Normal(Vindex(locs)[..., assignment, p], Vindex(scale)[..., assignment, p]), obs=data)