# Fitting posterior on params of Multivariate normal given partially-missing data

I am trying to recover the parameters of a 5-dimensional `MultivariateNormal` distribution. The catch is that instead of being given samples drawn from that distribution, I am given samples where some of the components of each vector are “masked off”.

I.e. I can generate example data like this:

``````
import numpy as np

N_dim = 5
N_samples = 1_000
off_diag_corr = 0.25
p_observe = 0.5

mu = np.linspace(2.5, 4.5, N_dim)

# constant off-diagonal correlation, unit diagonal
corr = np.ones((N_dim, N_dim)) * off_diag_corr
corr[np.diag_indices(N_dim, 2)] = 1.0
vols = np.linspace(0.5, 1.5, N_dim)
covar = np.diag(vols).dot(corr).dot(np.diag(vols))

samples = np.random.multivariate_normal(mu, cov=covar, size=N_samples)
inclusions = np.random.binomial(1, p=p_observe, size=(N_samples, N_dim))

# put in some gibberish where we do *not* have inclusions in our data set
samples = np.where(inclusions, samples, 1e6)
``````

Basically I’m saying I get these `samples`, but I’m going to force myself to only “see” the portion of those samples where the corresponding `inclusions` flags are 1.
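To make the masking concrete, here is a small sketch in plain NumPy that computes per-dimension means using only the entries flagged as observed. The seeded generator and the helper name `observed_column_means` are assumptions for illustration and are not part of the original post; the data setup mirrors the one above.

```python
import numpy as np

def observed_column_means(samples, inclusions):
    """Per-column mean over entries whose inclusion flag is 1 (hypothetical helper)."""
    mask = inclusions.astype(bool)
    # Zero out masked entries so the 1e6 placeholder never enters the sum.
    totals = np.where(mask, samples, 0.0).sum(axis=0)
    counts = mask.sum(axis=0)
    return totals / counts

# Same setup as the example data above, but with a seeded generator (an assumption).
rng = np.random.default_rng(0)
N_dim, N_samples = 5, 1_000
mu = np.linspace(2.5, 4.5, N_dim)
vols = np.linspace(0.5, 1.5, N_dim)
corr = np.full((N_dim, N_dim), 0.25)
np.fill_diagonal(corr, 1.0)
covar = np.diag(vols) @ corr @ np.diag(vols)

samples = rng.multivariate_normal(mu, covar, size=N_samples)
inclusions = rng.binomial(1, 0.5, size=(N_samples, N_dim))
samples = np.where(inclusions, samples, 1e6)

obs_means = observed_column_means(samples, inclusions)
print(obs_means)  # expected to land near mu, since masking is independent of the values
```

Because the mask here is drawn independently of the values, the observed-only means should be close to `mu`; that would not hold if missingness depended on the data.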

My instinct was to do something like this:

``````
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(samples, inclusions):
    with numpyro.plate('dim', samples.shape[1]):
        mu = numpyro.sample('mu', dist.Normal(0, 15))
        s = numpyro.sample('s', dist.HalfNormal(5))

    correl_chol = numpyro.sample('correl_chol', dist.LKJCholesky(samples.shape[1], concentration=1.0))

    with numpyro.plate('observations', samples.shape[0]):
        # latent stand-ins for the masked entries
        samples_imputed = numpyro.sample('samples_imputed', dist.MultivariateNormal(mu, scale_tril=s[..., None] * correl_chol))
        # observed values where available, imputed values elsewhere
        samples_merged = numpyro.deterministic('samples_merged', jnp.where(inclusions, samples, samples_imputed))
        numpyro.sample('samples', dist.MultivariateNormal(mu, scale_tril=s[..., None] * correl_chol), obs=samples_merged)
``````

This compiles, but runs horribly slowly even at the modest scale above. I’m a bit lost. Is this a reasonable approach conceptually? Is it reasonable computationally?

For inference, you might want to skip the `samples_imputed` `sample` statement.

I’m not sure I follow. Without that line, where will I get something to “fill in” for the missing values in the subsequent line?

You can remove that line (the `samples_merged` one) too. Alternatively, you can do:

``````
def model(samples, inclusions=None):
    ...
    with numpyro.plate('observations', samples.shape[0]):
        if inclusions is not None:
            samples_imputed = numpyro.sample('samples_imputed', dist.MultivariateNormal(mu, scale_tril=s[..., None] * correl_chol))
            samples_merged = numpyro.deterministic('samples_merged', jnp.where(inclusions, samples, samples_imputed))
        ...
``````

and run MCMC/SVI with `inclusions=None`. After getting the posterior samples, you can run `Predictive` with some `inclusions`.
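As a point of comparison with the imputation route, and not something proposed in this thread: the observed entries of each row have a closed-form likelihood, because the marginal of a multivariate normal over a subset of coordinates is again normal, with the matching sub-vector of the mean and sub-block of the covariance. The sketch below shows that calculation in plain NumPy. The helper name `observed_loglik` and the numbers are made up for illustration. In a NumPyro model a term like this could presumably be added through `numpyro.factor`, though that is untested here.

```python
import numpy as np

def observed_loglik(x, seen, mu, covar):
    """Log-density of the observed entries of one row (hypothetical helper).

    Uses the fact that dropping coordinates from a multivariate normal
    leaves a normal with the corresponding sub-mean and sub-covariance.
    """
    idx = np.flatnonzero(seen)
    if idx.size == 0:
        return 0.0  # a fully masked row carries no information
    diff = x[idx] - mu[idx]
    sub = covar[np.ix_(idx, idx)]
    _, logdet = np.linalg.slogdet(sub)
    quad = diff @ np.linalg.solve(sub, diff)
    return -0.5 * (idx.size * np.log(2 * np.pi) + logdet + quad)

# Tiny usage example with made-up numbers.
mu = np.array([1.0, 2.0, 3.0])
covar = np.array([[1.0, 0.3, 0.1],
                  [0.3, 2.0, 0.4],
                  [0.1, 0.4, 1.5]])
row = np.array([0.5, 1e6, 3.2])   # middle entry is a masked placeholder
seen = np.array([1, 0, 1])
ll = observed_loglik(row, seen, mu, covar)
print(ll)
```

The placeholder value never enters the calculation, so no latent variable per missing entry is needed. The trade-off is that rows with different missingness patterns need different sub-blocks, which is awkward to vectorise.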