I am using `dist.Dirichlet()` to model probabilities, but when I look at the retrodictions after fitting, they are not what I expect. I suspect I am doing something fundamentally wrong. Here is a simple example with 5 groups, where the first column has a very low probability and the last column a high one.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as random
import numpyro
from numpyro import distributions as dist
from numpyro.infer import NUTS, MCMC, Predictive
import arviz as az
numpyro.set_host_device_count(1)
N_ = 100
X = jax.nn.softmax(np.random.normal(np.arange(-4, 4.1, step=2), scale=2., size=(N_, 5)), axis=-1)
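def model1(X=None):
    # one unconstrained parameter per group (column of X)
    with numpyro.plate('plate_groups', 5):
        a = numpyro.sample('a', dist.Normal(0, 5.))
    # Dirichlet concentrations must be positive
    b = numpyro.deterministic('b', jax.nn.softplus(a))
    # one Dirichlet draw per row of X (N_ rows when predicting without data)
    n_rows = N_ if X is None else X.shape[0]
    with numpyro.plate('plate_people', n_rows):
        numpyro.sample('obs', dist.Dirichlet(b), obs=X)

hmc = MCMC(NUTS(model1, target_accept_prob=0.9), num_chains=4, num_warmup=1000,
           num_samples=1000, progress_bar=True, chain_method='sequential')
hmc.run(random.PRNGKey(32), X=X)
ppc = Predictive(model1, hmc.get_samples())(random.PRNGKey(48))
idata = az.from_numpyro(hmc, posterior_predictive=ppc)
# column-wise means of the data and of all posterior predictive draws
print(np.asarray(X).mean(0).round(3))
print(idata.posterior_predictive['obs'].values.reshape(-1, 5).mean(0).round(3))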
```

Looking just at the first column: in the data `X`, both the mean and the 90th percentile are below 0.005.
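A minimal NumPy-only sketch of how these per-column statistics can be computed. It regenerates data with the same recipe as `X`, but assumes a fixed seed and writes the softmax out by hand instead of calling `jax.nn.softmax`; the helper name `column_summary` is hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    # numerically stable softmax along `axis`
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def column_summary(samples, q=90):
    # per-column mean and q-th percentile of an (n_rows, n_groups) array
    samples = np.asarray(samples)
    return samples.mean(axis=0), np.percentile(samples, q, axis=0)

rng = np.random.default_rng(0)  # hypothetical seed, not in the original
N_ = 100
# same recipe as X above: logits centred at -4, -2, 0, 2, 4 with scale 2
X = softmax(rng.normal(np.arange(-4, 4.1, step=2), scale=2.0, size=(N_, 5)))

mean, p90 = column_summary(X)
print("mean:", mean.round(3))
print("p90: ", p90.round(3))
```

The same `column_summary` can be applied to the posterior predictive draws after reshaping them to `(-1, 5)`.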

In the predictions, however, the mean is 0.05 and the 90th percentile is 0.15.
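For reference, a Dirichlet with concentration vector `b` has mean `b_k / sum_j b_j`, so the predicted first-column statistics are determined entirely by `b`. The sketch below simulates draws with NumPy's `Generator.dirichlet` for a hypothetical `b` (not a fitted posterior value) and checks the sample mean against that formula; substituting posterior draws of `b` would be the analogous check on the fit.

```python
import numpy as np

# hypothetical concentration vector; not taken from the fitted model
b = np.array([0.3, 0.6, 1.2, 2.5, 5.0])

rng = np.random.default_rng(1)
draws = rng.dirichlet(b, size=200_000)  # shape (200_000, 5); each row sums to 1

analytic_mean = b / b.sum()             # Dirichlet mean: b_k / sum_j b_j
sample_mean = draws.mean(axis=0)
p90 = np.percentile(draws, 90, axis=0)

print("analytic mean:", analytic_mean.round(3))
print("sample mean:  ", sample_mean.round(3))
print("90th pct:     ", p90.round(3))
```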

How can I capture this in the model? I’ve tried increasing the number of rows (`N_`), but it has no impact.