Fitting posterior on params of Multivariate normal given partially-missing data

I am trying to recover the parameters of a 5-dimensional MultivariateNormal distribution. The catch is that instead of being given samples drawn from that distribution, I am given samples where some of the components of each vector are “masked off”.

I.e. I can generate example data like this:

import numpy as np

N_dim = 5
N_samples = 1_000
off_diag_corr = 0.25
p_observe = 0.5

mu = np.linspace(2.5, 4.5, N_dim)

corr = np.ones((N_dim, N_dim)) * off_diag_corr
corr[np.diag_indices(N_dim, 2)] = 1.0
vols = np.linspace(0.5, 1.5, N_dim)
covar = np.diag(vols).dot(corr).dot(np.diag(vols))

samples = np.random.multivariate_normal(mu, cov=covar, size=N_samples)
inclusions = np.random.binomial(1, p=p_observe, size=(N_samples, N_dim))

# put in some gibberish where we do *not* have inclusions in our data set
samples = np.where(inclusions, samples, 1e6)

Basically I’m saying I get these samples, but I’m going to force myself to only “see” the portion of those samples where the corresponding inclusions flags are 1.
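As a quick sanity check on that setup, here is a minimal NumPy sketch (not part of the original post) that regenerates the same kind of data and compares per-dimension means computed only over the observed entries against means computed naively over the whole array. The use of `np.random.default_rng` with a fixed seed and the names `masked`, `observed_mean` and `naive_mean` are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility

N_dim, N_samples = 5, 1_000
off_diag_corr, p_observe = 0.25, 0.5

# Same construction as above: constant off-diagonal correlation, linear vols.
mu = np.linspace(2.5, 4.5, N_dim)
corr = np.full((N_dim, N_dim), off_diag_corr)
np.fill_diagonal(corr, 1.0)
vols = np.linspace(0.5, 1.5, N_dim)
covar = np.diag(vols) @ corr @ np.diag(vols)

samples = rng.multivariate_normal(mu, covar, size=N_samples)
inclusions = rng.binomial(1, p=p_observe, size=(N_samples, N_dim))
samples = np.where(inclusions, samples, 1e6)  # gibberish where unobserved

# Hide the gibberish behind a mask so that only observed entries are used.
masked = np.ma.masked_array(samples, mask=(inclusions == 0))
observed_mean = masked.mean(axis=0).filled(np.nan)

# The naive mean is dominated by the 1e6 placeholder values.
naive_mean = samples.mean(axis=0)

print("true mu      :", mu)
print("observed mean:", observed_mean)
print("naive mean   :", naive_mean)
```

The observed-only means should land near `mu`, while the naive means are swamped by the placeholder, which is why the model must never treat the placeholder values as real observations.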

My instinct was to do something like this:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(samples, inclusions):
    with numpyro.plate('dim', samples.shape[1]):
        mu = numpyro.sample('mu', dist.Normal(0, 15))
        s = numpyro.sample('s', dist.HalfNormal(5))

    correl_chol = numpyro.sample('correl_chol', dist.LKJCholesky(samples.shape[1], concentration=1.0))

    with numpyro.plate('observations', samples.shape[0]):
        # latent stand-ins for the entries we did not observe
        samples_imputed = numpyro.sample('samples_imputed', dist.MultivariateNormal(mu, scale_tril=s[..., None] * correl_chol))
        # keep observed entries, fill the rest with the imputed values
        samples_merged = numpyro.deterministic('samples_merged', jnp.where(inclusions, samples, samples_imputed))
        numpyro.sample('samples', dist.MultivariateNormal(mu, scale_tril=s[..., None] * correl_chol), obs=samples_merged)

This compiles, but runs horribly slowly even at the modest scale above. I’m a bit lost. Is this a reasonable approach conceptually? Is it reasonable computationally?

For inference, you might want to skip the samples_imputed sample statement.

I’m not sure I follow. Without that line, where will I get something to “fill in” for the missing values in the subsequent line?

You can remove that line (the samples_merged one) too. Alternatively, you can do:

def model(samples, inclusions=None):
    ...
    with numpyro.plate('observations', samples.shape[0]):
        if inclusions is not None:
            samples_imputed = numpyro.sample('samples_imputed', dist.MultivariateNormal(mu, scale_tril=s[..., None] * correl_chol))
            samples_merged = numpyro.deterministic('samples_merged', jnp.where(inclusions, samples, samples_imputed))
        ...

and run MCMC/SVI with inclusions=None. After getting the posterior samples, you can run Predictive with some inclusions.
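To make the imputation step concrete, here is a rough NumPy sketch of what the samples_imputed / samples_merged lines compute for a single posterior draw. It does not use NumPyro or Predictive itself. The values `mu_draw`, `s_draw` and `corr_draw` are hypothetical stand-ins for one posterior sample of mu, s and the correlation matrix (here simply set to the true generating values), and the tiny data set is made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)  # fixed seed, for reproducibility only
N_samples, N_dim = 6, 5

# Hypothetical stand-ins for ONE posterior draw of the parameters.
mu_draw = np.linspace(2.5, 4.5, N_dim)
s_draw = np.linspace(0.5, 1.5, N_dim)
corr_draw = np.full((N_dim, N_dim), 0.25)
np.fill_diagonal(corr_draw, 1.0)

# Same construction as in the model: scale each row of the correlation
# Cholesky factor by s, so scale_tril @ scale_tril.T is the covariance.
correl_chol = np.linalg.cholesky(corr_draw)
scale_tril = s_draw[:, None] * correl_chol

# Toy partially observed data, built like the data in the question.
samples = rng.multivariate_normal(mu_draw, scale_tril @ scale_tril.T, size=N_samples)
inclusions = rng.binomial(1, 0.5, size=(N_samples, N_dim))
samples = np.where(inclusions, samples, 1e6)

# 'samples_imputed': one full draw per row from MVN(mu, scale_tril).
z = rng.standard_normal((N_samples, N_dim))
samples_imputed = mu_draw + z @ scale_tril.T

# 'samples_merged': observed entries kept, unobserved ones filled in.
samples_merged = np.where(inclusions, samples, samples_imputed)
print(np.round(samples_merged, 2))
```

Note that in this sketch, as in the model lines it mirrors, each imputed entry is drawn from the unconditional distribution for that row and is not conditioned on the entries that were observed in the same row.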