# Dirichlet with some entries excluded

Hello

I would like to use a Dirichlet distribution where some of the entries in each observation are excluded. It's easiest to explain with an example.

```python
import numpy as np
import jax
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import arviz as az

size = 32, 8
X = np.random.rand(*size) * np.arange(1, size[1]+1)
X = X / X.sum(1)[:, None]
eps = 1e-6
X = np.clip(X, eps, 1-eps)
# Which 5 of the 8 members took part in each project (True = included); reconstructed
# from the description later in the thread: argsort of noise is a random permutation per row
mask = np.argsort(np.random.rand(*size), axis=1) < 5

def model(X=None):
    unit_mean = numpyro.sample('unit_mean', dist.Normal(0, 1))
    unit_scale = numpyro.sample('unit_scale', dist.Exponential(1))
    unit_pre = numpyro.sample('unit_pre', dist.Normal(0., 1.5).expand([size[1]]))
    unit = numpyro.deterministic('unit_unconstrained', unit_mean + unit_scale*unit_pre)
    unit = numpyro.deterministic('unit', jax.nn.softplus(unit))
    obs = numpyro.sample('obs', dist.Dirichlet(unit), obs=X)

hmc = MCMC(NUTS(model, target_accept_prob=0.95), num_warmup=1000, num_samples=1000, num_chains=4)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), X=X)
ppc = Predictive(model, hmc.get_samples())(random.PRNGKey(1))
idata = az.from_numpyro(hmc, posterior_predictive=ppc)
print(az.summary(idata, var_names=['~pre', '~unconstrained'], filter_vars='regex'))
```

How can I use `mask` in the model?

One bulky approach would be to use a separate observed `dist.Dirichlet` for each row, manually removing the entries where `mask` is False, but I am hoping there is a cleaner method you can point me to.

```python
def model2(mask, X=None):
    unit_mean = numpyro.sample('unit_mean', dist.Normal(0, 1))
    unit_scale = numpyro.sample('unit_scale', dist.Exponential(1))
    unit_pre = numpyro.sample('unit_pre', dist.Normal(0., 1.5).expand([size[1]]))
    unit = numpyro.deterministic('unit_unconstrained', unit_mean + unit_scale*unit_pre)
    unit = numpyro.deterministic('unit', jax.nn.softplus(unit))
    # One observed Dirichlet per row, keeping only that row's included entries
    for j in range(mask.shape[0]):
        this_loc = unit[mask[j]]
        this_X = None if X is None else X[j][mask[j]]
        obs = numpyro.sample(f'obs{j}', dist.Dirichlet(this_loc), obs=this_X)
```

@nkaimcaudle I guess you can use the mask method: `dist.Dirichlet(this_loc).mask(mask)`.

I would like something that doesn't require looping over the rows. The loop with the mask method does work, but it is many times slower than the model without the loop.

I tried the following, but it fails:

```python
def model3(mask, X=None):
    unit_mean = numpyro.sample('unit_mean', dist.Normal(0, 1))
    unit_scale = numpyro.sample('unit_scale', dist.Exponential(1))
    unit_pre = numpyro.sample('unit_pre', dist.Normal(0., 1.5).expand([size[1]]))
    unit = numpyro.deterministic('unit_unconstrained', unit_mean + unit_scale*unit_pre)
    unit = numpyro.deterministic('unit', jax.nn.softplus(unit))
    # Intended: drop the excluded entries of each row (this is the line that raises the error)
    obs = numpyro.sample('obs', dist.Dirichlet(unit).mask(mask), obs=X)

hmc = MCMC(NUTS(model3, target_accept_prob=0.95), num_warmup=1000, num_samples=1000, num_chains=4)
hmc.run(random.PRNGKey(0), mask, X=X)
ppc = Predictive(model3, hmc.get_samples())(random.PRNGKey(1), mask)
idata3 = az.from_numpyro(hmc, posterior_predictive=ppc)
az.summary(idata3, var_names=['~pre', '~unconstrained'], filter_vars='regex')
```

The error message is:

`ValueError: Incompatible shapes for broadcasting: ((100, 8), (1, 100))`

If I change the likelihood to the following, it does run, but it gives incorrect results:

`obs = numpyro.sample('obs', dist.Dirichlet(unit).mask(mask.T), obs=X)`

@nkaimcaudle Usually, you don't need to use a for loop. Could you print out the shapes of `unit`, `mask`, and `X` so I can get a better idea of what is going on?

You can use the code in my first post to generate fake data; `X` and `mask` are defined there. `unit` is defined within the model code.

Oh, we have `mask.shape == X.shape == (32, 8)`. I don't think this will work for the Dirichlet distribution. Here we have `(32,)` Dirichlet samples, each with 8 entries, while the mask is `(32, 8)`. `.mask()` only applies to batch dimensions, i.e. `(32,)` in this case, so the error happens. IIUC, you have a collection of Dirichlet samples with different shapes, and you want to construct a model for those samples. This is tricky. I don't know if we have an easy solution for this.

Yes, your understanding is correct. There are 8 members of the group; on each project, 5 of them are randomly selected, and I want to model the proportion of the project's output that each member contributes.
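For testing, fake data that follows this description can be generated as below. This is a hypothetical generator, not code from the thread: `true_conc` is an assumed per-member concentration, and each row's proportions are drawn from a Dirichlet over only the 5 selected members, so the selected entries of a row sum to 1 and the others are zero.

```python
import numpy as np

def make_fake_projects(n_projects=32, n_members=8, k=5, seed=0):
    """Hypothetical generator: k of n_members work on each project and split its output."""
    rng = np.random.default_rng(seed)
    true_conc = np.arange(1, n_members + 1, dtype=float)  # assumed member "strengths"
    # argsort of uniform noise is a random permutation per row; entries < k mark k random members
    mask = np.argsort(rng.random((n_projects, n_members)), axis=1) < k
    X = np.zeros((n_projects, n_members))
    for j in range(n_projects):
        # Proportions of project j's output among its selected members only
        X[j, mask[j]] = rng.dirichlet(true_conc[mask[j]])
    return X, mask

X, mask = make_fake_projects()
print(X.shape, mask.shape)   # (32, 8) (32, 8)
print(mask.sum(axis=1)[:3])  # [5 5 5]
```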

If the shape after masking is constant (`(5,)` in this case), you can use `jnp.take_along_axis` for the job:

```python
unit_sel = jnp.take_along_axis(jnp.broadcast_to(unit, mask.shape), mask_indices, axis=-1)  # same ndim as indices
numpyro.sample('obs', dist.Dirichlet(unit_sel), obs=jnp.take_along_axis(X, mask_indices, axis=-1))
```

Here `mask_indices` (shape `(32, 5)`) holds, for each row, the column indices where `mask` (shape `(32, 8)`) is True.

Thanks a lot, that may be suitable after some refactoring.

Let me give it a try. It's good to know about that function anyway.