Dirichlet with some entries excluded

Hello

I would like to use a Dirichlet distribution where some of the entries in each observation are excluded. It's easiest to explain with an example.

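import numpy as np
import jax
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import arviz as az

size = 32, 8
X = np.random.rand(*size) * np.arange(1, size[1]+1)
mask = np.random.rand(*size) > 0.2   # True = entry is included in this observation
X = X * mask
X = X / X.sum(1)[:, None]
eps = 1e-6
X = np.clip(X, eps, 1-eps)           # keep every entry strictly inside (0, 1)
X = X / X.sum(1)[:, None]            # renormalise so each row lies on the simplex again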

def model(X=None):
    unit_mean = numpyro.sample('unit_mean', dist.Normal(0, 1))
    unit_scale = numpyro.sample('unit_scale', dist.Exponential(1))
    unit_pre = numpyro.sample('unit_pre', dist.Normal(0., 1.5).expand([size[1]]))
    unit = numpyro.deterministic('unit_unconstrained', unit_mean + unit_scale*unit_pre)
    unit = numpyro.deterministic('unit', jax.nn.softplus(unit))
    obs = numpyro.sample('obs', dist.Dirichlet(unit), obs=X)
hmc = MCMC(NUTS(model, target_accept_prob=0.95), num_warmup=1000, num_samples=1000, num_chains=4)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), X=X)
ppc = Predictive(model, hmc.get_samples())(random.PRNGKey(1), )
idata = az.from_numpyro(hmc, posterior_predictive=ppc)
print(az.summary(idata, var_names=['~pre', '~unconstrained'], filter_vars='regex'))

How can I use mask in the model?

One bulky approach would be to use a separate observed dist.Dirichlet for each row, manually removing the entries where mask is False, but I am hoping there is a cleaner method you can point me to.

def model2(mask, X=None):
    unit_mean = numpyro.sample('unit_mean', dist.Normal(0, 1))
    unit_scale = numpyro.sample('unit_scale', dist.Exponential(1))
    unit_pre = numpyro.sample('unit_pre', dist.Normal(0., 1.5).expand([size[1]]))
    unit = numpyro.deterministic('unit_unconstrained', unit_mean + unit_scale*unit_pre)
    unit = numpyro.deterministic('unit', jax.nn.softplus(unit))
    for j, (x, mask_row) in enumerate(zip(X, mask)):
        this_loc = unit[mask_row]
        this_X = x[mask_row]
        obs = numpyro.sample(f'obs{j}', dist.Dirichlet(this_loc), obs=this_X)

@nkaimcaudle I guess you can use the mask method: dist.Dirichlet(this_loc).mask(mask).

I would like something that doesn't require a loop over the rows. The mask method with the loop does work, but it is many times slower than the version without the loop.

I tried the following, but it fails:

def model3(mask, X=None):
    unit_mean = numpyro.sample('unit_mean', dist.Normal(0, 1))
    unit_scale = numpyro.sample('unit_scale', dist.Exponential(1))
    unit_pre = numpyro.sample('unit_pre', dist.Normal(0., 1.5).expand([size[1]]))
    unit = numpyro.deterministic('unit_unconstrained', unit_mean + unit_scale*unit_pre)
    unit = numpyro.deterministic('unit', jax.nn.softplus(unit))
    obs = numpyro.sample('obs', dist.Dirichlet(unit).mask(mask), obs=X)
hmc = MCMC(NUTS(model3, target_accept_prob=0.95), num_warmup=1000, num_samples=1000, num_chains=4)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), X=X, mask=mask)
ppc = Predictive(model, hmc.get_samples())(random.PRNGKey(1), )
idata3 = az.from_numpyro(hmc, posterior_predictive=ppc)
az.summary(idata3, var_names=['~pre', '~unconstrained'], filter_vars='regex')

The error message is:

ValueError: Incompatible shapes for broadcasting: ((100, 8), (1, 100))

If I change the likelihood to this, it runs, but gives incorrect results:

obs = numpyro.sample('obs', dist.Dirichlet(unit).mask(mask.T), obs=X)

@nkaimcaudle Usually, you don't need a for loop. Could you print out the shapes of unit, mask, and X so I can get a better idea of what is going on?

You can use the code in my first post to generate fake data; X and mask are defined there. unit is defined within the model code.

Oh, we have mask.shape == X.shape == (32, 8). I don't think this will work for the Dirichlet distribution. Here we have (32,) Dirichlet samples, each with event shape (8,), while the mask is (32, 8). .mask() only applies to batch dimensions, i.e. (32,) in this case, so the error happens. IIUC, you have a collection of Dirichlet samples with different shapes, and you want to construct a model for those samples. This is tricky. I don't know if we have an easy solution for this.

Yes, your understanding is correct. There are 8 members of the group; on each project 5 of them are randomly selected, and I want to model the proportion of the project's output each member contributes.

If the shape after masking is constant ((5,) in this case), you can use jnp.take_along_axis for the job:

numpyro.sample('obs', dist.Dirichlet(jnp.take_along_axis(jnp.broadcast_to(unit, X.shape), mask_indices, axis=-1)), obs=jnp.take_along_axis(X, mask_indices, axis=-1))

Here mask_indices (shape (32, 5)) holds the column indices of the True entries of mask (shape (32, 8)).

Thanks a lot, that may be suitable after some refactoring.

Let me give it a try. It's good to know about that function anyway.