I am using dist.Dirichlet() to model probabilities, but when I look at the retrodictions after fitting, they are not what I expect. I suspect I am doing something fundamentally wrong. Here is a simple example with 5 groups, where the first column has a very low probability and the last column a high one.
```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as random
import numpyro
from numpyro import distributions as dist
from numpyro.infer import NUTS, MCMC, Predictive
import arviz as az

numpyro.set_host_device_count(1)

# 100 rows, 5 columns: softmax of Gaussian noise centred at -4, -2, 0, 2, 4
N_ = 100
X = jax.nn.softmax(
    np.random.normal(np.arange(-4, 4.1, step=2), scale=2., size=(N_, 5)),
    axis=-1,
)

def model1(X=None):
    # one concentration parameter per column
    with numpyro.plate('plate_people', 5):
        a = numpyro.sample('a', dist.Normal(0, 5.))
    b = numpyro.deterministic('b', jax.nn.softplus(a))
    obs = numpyro.sample('obs', dist.Dirichlet(b), obs=X)

hmc = MCMC(NUTS(model1, target_accept_prob=0.9), num_chains=4, num_warmup=1000,
           num_samples=1000, progress_bar=True, chain_method='sequential')
hmc.run(random.PRNGKey(32), X=X)

# posterior predictive: one simplex draw per posterior sample
ppc = Predictive(model1, hmc.get_samples())(random.PRNGKey(48))
idata = az.from_numpyro(hmc, posterior_predictive=ppc)

print(np.asarray(X).mean(0).round(3))
print(idata.posterior_predictive['obs'].mean(('chain', 'draw')).values.round(3))
```
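As a quick check on the description above, the noise-free centre of each row is the softmax of the location vector passed to np.random.normal. A minimal NumPy sketch; the softmax helper is defined here for illustration and is not taken from the post:

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    expd = np.exp(shifted)
    return expd / expd.sum(axis=-1, keepdims=True)

# Locations of the Gaussian noise for the five columns in the example.
centres = np.arange(-4, 4.1, step=2)  # -4, -2, 0, 2, 4
p_centre = softmax(centres)
print(p_centre.round(3))
```

With scale=2. the noise is comparable to the gaps between the centres, so the column means of X will not equal this vector exactly; the sketch only shows the ordering and rough magnitude of the five columns.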
Looking just at the first column: in the data X, both the mean and the 90th percentile are below 0.005.
In the posterior predictions, however, the mean is 0.05 and the 90th percentile is 0.15.
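The two summaries quoted above (mean and 90th percentile of the first column) can be computed the same way for the data and for any array of simulated draws. A minimal NumPy sketch, assuming a fresh seed (so the numbers will not match the post exactly); summarize_column is a hypothetical helper, and alpha is an arbitrary illustrative concentration vector, not the fitted posterior. Under a Dirichlet(alpha), each component k is marginally Beta(alpha[k], alpha.sum() - alpha[k]), so the first column has mean alpha[0] / alpha.sum().

```python
import numpy as np

def summarize_column(samples, col=0, q=0.9):
    """Return (mean, q-quantile) of one column of an (n_draws, K) array."""
    column = np.asarray(samples)[:, col]
    return column.mean(), np.quantile(column, q)

rng = np.random.default_rng(0)
N_ = 100

# Same recipe as the question: softmax of Gaussian noise around -4, -2, 0, 2, 4.
z = rng.normal(np.arange(-4, 4.1, step=2), scale=2.0, size=(N_, 5))
z = z - z.max(axis=-1, keepdims=True)  # stabilise the exponentials
X = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

# Draws from a Dirichlet with a hypothetical concentration vector.
alpha = np.array([0.2, 0.4, 0.8, 1.5, 3.0])
D = rng.dirichlet(alpha, size=4000)

data_mean, data_q90 = summarize_column(X)
pred_mean, pred_q90 = summarize_column(D)
print(f"data:      mean={data_mean:.4f}  q90={data_q90:.4f}")
print(f"dirichlet: mean={pred_mean:.4f}  q90={pred_q90:.4f}")
```

Applying the same helper to the flattened posterior-predictive draws of obs gives the predictive side of the comparison.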
How can I capture this in the model? I have tried increasing the number of rows (N_), but it makes no difference.