I am trying to recover the parameters of a 5-dimensional `MultivariateNormal` distribution. The catch is that instead of being given samples drawn from that distribution, I am given samples where some of the components of each vector are “masked off”.

That is, I can generate example data like this:

```python
import numpy as np

N_dim = 5
N_samples = 1_000
off_diag_corr = 0.25
p_observe = 0.5
mu = np.linspace(2.5, 4.5, N_dim)
corr = np.ones((N_dim, N_dim)) * off_diag_corr
corr[np.diag_indices(N_dim, 2)] = 1.0
vols = np.linspace(0.5, 1.5, N_dim)
covar = np.diag(vols).dot(corr).dot(np.diag(vols))
samples = np.random.multivariate_normal(mu, cov=covar, size=N_samples)
inclusions = np.random.binomial(1, p=p_observe, size=(N_samples, N_dim))
# put in some gibberish where we do *not* have inclusions in our data set
samples = np.where(inclusions, samples, 1e6)
```
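Because the covariance is built as `diag(vols) @ corr @ diag(vols)`, its Cholesky factor is the Cholesky factor of `corr` with row `i` scaled by `vols[i]`. The short NumPy check below (a sketch that rebuilds only the covariance pieces of the snippet above) confirms that identity. The `vols[..., None] * L` form is the same row-scaling that a `scale_tril` argument built from per-dimension scales and a correlation Cholesky factor uses:

```python
import numpy as np

N_dim = 5
corr = np.full((N_dim, N_dim), 0.25)
np.fill_diagonal(corr, 1.0)
vols = np.linspace(0.5, 1.5, N_dim)
covar = np.diag(vols) @ corr @ np.diag(vols)

# Cholesky factor of the correlation matrix: corr == L @ L.T
L = np.linalg.cholesky(corr)
# Broadcasting vols[..., None] scales row i of L by vols[i], i.e. diag(vols) @ L
scale_tril = vols[..., None] * L

print(np.allclose(scale_tril, np.diag(vols) @ L))          # True
print(np.allclose(scale_tril @ scale_tril.T, covar))       # True
print(np.allclose(scale_tril, np.linalg.cholesky(covar)))  # True
```

The last check holds because a lower-triangular factor with a positive diagonal is unique, and scaling by positive `vols` preserves both properties.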

Basically, I get these `samples`, but I force myself to “see” only the entries whose corresponding `inclusions` flags are 1.
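One way to make “only seeing” the flagged entries concrete is NumPy's masked arrays (`np.ma`), which skip masked entries in reductions so the `1e6` filler never leaks into statistics. This is a sketch, not part of the original setup: it regenerates the data with a seeded `np.random.default_rng` so that it runs on its own.

```python
import numpy as np

# Regenerate data the same way as above, with a seeded Generator for reproducibility.
rng = np.random.default_rng(0)
N_dim, N_samples, p_observe = 5, 1_000, 0.5
mu = np.linspace(2.5, 4.5, N_dim)
corr = np.full((N_dim, N_dim), 0.25)
np.fill_diagonal(corr, 1.0)
vols = np.linspace(0.5, 1.5, N_dim)
covar = np.diag(vols) @ corr @ np.diag(vols)
samples = rng.multivariate_normal(mu, covar, size=N_samples)
inclusions = rng.binomial(1, p_observe, size=(N_samples, N_dim))
samples = np.where(inclusions, samples, 1e6)

# Mask every entry whose inclusion flag is 0.
observed = np.ma.masked_array(samples, mask=(inclusions == 0))

n_seen = observed.count(axis=0)                   # visible entries per dimension
mean_seen = observed.mean(axis=0).filled(np.nan)  # mean of visible entries only
n_empty_rows = int((inclusions.sum(axis=1) == 0).sum())  # rows with nothing visible
print(n_seen, mean_seen, n_empty_rows)
```

With `p_observe = 0.5`, each dimension keeps roughly half its entries, and a few rows end up fully masked and carry no information at all.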

My instinct was to do something like this:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(samples, inclusions):
    # Per-dimension priors on the means and scales
    with numpyro.plate('dim', samples.shape[1]):
        mu = numpyro.sample('mu', dist.Normal(0, 15))
        s = numpyro.sample('s', dist.HalfNormal(5))
    # A single correlation Cholesky factor (outside the 'dim' plate)
    correl_chol = numpyro.sample('correl_chol', dist.LKJCholesky(samples.shape[1], concentration=1.0))
    with numpyro.plate('observations', samples.shape[0]):
        # Latent draw for every row; masked-off entries are filled in from it
        samples_imputed = numpyro.sample('samples_imputed', dist.MultivariateNormal(mu, scale_tril=s[..., None] * correl_chol))
        samples_merged = numpyro.deterministic('samples_merged', jnp.where(inclusions, samples, samples_imputed))
        numpyro.sample('samples', dist.MultivariateNormal(mu, scale_tril=s[..., None] * correl_chol), obs=samples_merged)
```

This compiles, but runs horribly slowly even at the modest scale above. I’m a bit lost. Is this a reasonable approach conceptually? Is it reasonable computationally?