Hello,
I would like to fit a Dirichlet distribution to data in which some entries of each observation are missing (masked out). It is easiest to explain with an example.
# Synthetic compositional data: size[0] observations over size[1] categories.
size = (32, 8)
# Clipping bound that keeps every entry strictly inside (0, 1).
eps = 1e-6

# Raw non-negative weights; column k is scaled by k + 1.
X = np.random.rand(*size) * np.arange(1, size[1] + 1)
# True marks an entry that is kept; roughly 80% of the entries are kept.
mask = np.random.rand(*size) > 0.2

# Zero the dropped entries, then rescale every row so that it sums to one.
# NOTE: a row in which every entry is dropped has sum 0 and becomes NaN here.
X = np.where(mask, X, 0.0)
X /= X.sum(axis=1, keepdims=True)
# After clipping, dropped entries hold eps instead of an exact zero.
X = X.clip(eps, 1 - eps)
def model(X=None):
    """Dirichlet model with one concentration vector shared by all rows.

    The concentration is built from a scalar location ('unit_mean'), a scalar
    scale ('unit_scale') and a per-category offset ('unit_pre'), combined as
    ``unit_mean + unit_scale * unit_pre`` and mapped through softplus so that
    it is positive.

    Args:
        X: Observations passed as ``obs`` to the 'obs' site. When None, the
            site is sampled instead of conditioned on data.
    """
    # The number of categories comes from the module-level ``size``.
    n_categories = size[1]

    loc = numpyro.sample('unit_mean', dist.Normal(0, 1))
    scale = numpyro.sample('unit_scale', dist.Exponential(1))
    offsets = numpyro.sample(
        'unit_pre', dist.Normal(0., 1.5).expand([n_categories])
    )

    # Record both the unconstrained value and its positive transform.
    unconstrained = numpyro.deterministic(
        'unit_unconstrained', loc + scale * offsets
    )
    concentration = numpyro.deterministic(
        'unit', jax.nn.softplus(unconstrained)
    )

    numpyro.sample('obs', dist.Dirichlet(concentration), obs=X)
# Posterior sampling: NUTS kernel, 4 chains, 1000 warmup + 1000 kept draws each.
kernel = NUTS(model, target_accept_prob=0.95)
hmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4)
hmc.run(
    random.PRNGKey(0),
    X=X,
    extra_fields=("z", "energy", "diverging"),
)

# Posterior predictive draws: the model is called without X, so 'obs' is sampled.
predictive = Predictive(model, hmc.get_samples())
ppc = predictive(random.PRNGKey(1))

# Bundle everything for ArviZ and print a summary; the '~' prefix together with
# filter_vars='regex' excludes variables matching 'pre' or 'unconstrained'.
idata = az.from_numpyro(hmc, posterior_predictive=ppc)
summary = az.summary(
    idata,
    var_names=['~pre', '~unconstrained'],
    filter_vars='regex',
)
print(summary)
How can I use `mask` in the model?
One bulky approach would be to use a separate observed `dist.Dirichlet` for each row, manually removing the entries where `mask` is False (see `model2` below), but I am hoping there is a cleaner method you can point me to.
def model2(mask, X=None):
    """Dirichlet model that uses only the unmasked entries of each row.

    Each row ``j`` gets its own 'obs{j}' site: a Dirichlet over the categories
    where ``mask[j]`` is True, with the matching entries of the shared
    concentration vector 'unit'.

    Args:
        mask: Boolean array with one row per observation and one column per
            category; True marks an entry that is kept. Each row is used as a
            boolean index, so it presumably has to be a concrete (non-traced)
            array -- confirm against the JAX indexing rules.
        X: Optional observations, row-aligned with ``mask``. When None, every
            'obs{j}' site is sampled instead of conditioned on data, so the
            model can also be used for predictive sampling.
    """
    # Take the category count from the mask itself, not a module-level global.
    n_categories = mask.shape[-1]

    unit_mean = numpyro.sample('unit_mean', dist.Normal(0, 1))
    unit_scale = numpyro.sample('unit_scale', dist.Exponential(1))
    unit_pre = numpyro.sample(
        'unit_pre', dist.Normal(0., 1.5).expand([n_categories])
    )
    unit = numpyro.deterministic(
        'unit_unconstrained', unit_mean + unit_scale * unit_pre
    )
    unit = numpyro.deterministic('unit', jax.nn.softplus(unit))

    # Without data there is nothing to zip over: pair every mask row with None
    # so the loop still creates one (unobserved) site per row.
    rows = [None] * len(mask) if X is None else X
    for j, (x, mask_row) in enumerate(zip(rows, mask)):
        this_loc = unit[mask_row]
        this_X = None if x is None else x[mask_row]
        numpyro.sample(f'obs{j}', dist.Dirichlet(this_loc), obs=this_X)