Missing data in a mixture model

Potentially a very naive question, as I’m still coming to grips with (Num)Pyro, but how does one handle missing values in the data within a mixture model? From the tutorials, I would have thought that, since the model has a density describing the data, no explicit handling of the missing data would be required. However, when I run the model below on a dataset with some missing values, I get a runtime error; when I run it on the same data with the missing values removed, the model samples with no problems. And since the imputation is conditional on the assignment, I am not sure how to impute values elegantly.

If someone can explain how to handle this it would be much appreciated! Thank you.
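On the imputation point, one way to think about it: for a single draw of the mixture parameters, a row's assignment can first be drawn from its posterior given only the observed entries, and the missing entries then drawn from that component. Below is a minimal JAX sketch of that idea. It assumes diagonal Gaussian components as in `gaussianMixtureModel`, missing values encoded as NaN, and a hypothetical helper name `impute_row`. It is a post-hoc step applied to parameter draws, not something NumPyro does for you.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def impute_row(key, row, weights, locs, scale):
    """Hypothetical helper: impute the NaN entries of one row.

    row: (P,) with NaN marking missing values; weights: (K,);
    locs, scale: (K, P). Assumes diagonal Gaussian components.
    """
    observed = ~jnp.isnan(row)
    filled = jnp.where(observed, row, 0.0)  # placeholder, never used as data
    # Per-component log-likelihood of the observed coordinates only; with a
    # diagonal Gaussian the missing coordinates simply marginalise out.
    ll = jnp.where(observed, norm.logpdf(filled, locs, scale), 0.0).sum(axis=-1)
    log_resp = jnp.log(weights) + ll  # unnormalised log p(z | observed part)
    key_z, key_x = jax.random.split(key)
    # Draw the assignment from its conditional posterior.
    z = jax.random.categorical(key_z, log_resp)
    # Draw every coordinate from the chosen component, keep only the missing ones.
    draw = locs[z] + scale[z] * jax.random.normal(key_x, row.shape)
    return jnp.where(observed, row, draw), z


row = jnp.array([1.0, jnp.nan])
weights = jnp.array([0.5, 0.5])
locs = jnp.array([[0.0, 0.0], [10.0, 10.0]])
scale = jnp.ones((2, 2))
completed, z = impute_row(jax.random.PRNGKey(0), row, weights, locs, scale)
```

Repeating this for each posterior draw (for example with `jax.vmap` over draws and rows) gives a set of completed datasets that reflects uncertainty in both the assignment and the missing values.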

Error:

RuntimeError: Cannot find valid initial parameters. Please check your model again.

Model:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate
from numpyro.ops.indexing import Vindex


@config_enumerate
def gaussianMixtureModel(data, K):
    N, P = data.shape

    # Global variables: mixture weights and per-component, per-measurement parameters.
    weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    with numpyro.plate("components", K, dim=-2):
        with numpyro.plate("measurements", P, dim=-1):
            scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
            locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))

    with numpyro.plate("data", N, dim=-2):
        # Local variables: one (enumerated) component assignment per row.
        assignment = numpyro.sample("assignment", dist.Categorical(weights))
        with numpyro.plate("measurements", P, dim=-1) as p:
            # Every entry of `data` is conditioned on here, so if missing
            # values are stored as NaN the joint log-density becomes NaN.
            numpyro.sample(
                "obs",
                dist.Normal(Vindex(locs)[..., assignment, p], Vindex(scale)[..., assignment, p]),
                obs=data,
            )
```

Presumably you want to enclose the `obs` sample statement in an appropriate mask handler context (`numpyro.handlers.mask`)?
