Hello everyone,

I am not sure if this is the right place (would GitHub be better?), but I have been struggling with an issue that I believe is caused by machine precision and by how NumPyro implements the simplex constraint.

Consider this minimal example:

```
import jax
import jax.numpy as jnp
import numpyro.distributions as dist
n_topics = 10
n_docs = 50
alpha = 20/n_topics
key = jax.random.PRNGKey(0)
theta = dist.Dirichlet(jnp.ones([n_topics]) * alpha).sample(key, sample_shape = (n_docs,))
jnp.all(dist.constraints.simplex(theta))
```

Unfortunately this returns *False*, i.e. at least one draw from the Dirichlet fails the simplex constraint check (regardless of whether use_x64 is set to true or false).

I notice that Pyro implements the simplex constraint like this:

```
torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
```

while NumPyro opts for:

```
x_sum = jnp.sum(x, axis=-1)
jnp.all(x > 0, axis=-1) & (x_sum <= 1) & (x_sum > 1 - 1e-6)
```

So the difference seems to be that when `1 < x_sum < 1 + 1e-6`, Pyro will allow it but NumPyro will not.
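To make the difference concrete, here is a minimal sketch that re-implements both checks in plain JAX and applies them to a vector whose sum is one float32 step above 1. The function names `pyro_style_check` and `numpyro_style_check` are my own and are not library APIs; they only mirror the two expressions quoted above.

```python
import jax.numpy as jnp


def pyro_style_check(x, tol=1e-6):
    # Mirrors the Pyro expression: symmetric tolerance around 1.
    x_sum = jnp.sum(x, axis=-1)
    return jnp.all(x >= 0, axis=-1) & (jnp.abs(x_sum - 1) < tol)


def numpyro_style_check(x, tol=1e-6):
    # Mirrors the NumPyro expression: the sum may fall below 1 by tol,
    # but may not exceed 1 at all.
    x_sum = jnp.sum(x, axis=-1)
    return jnp.all(x > 0, axis=-1) & (x_sum <= 1) & (x_sum > 1 - tol)


# Two entries whose sum is 1 + 2**-23, i.e. the next float32 above 1.
x = jnp.array([0.5, 0.5 + 2.0 ** -23])

print("sum - 1:", float(jnp.sum(x) - 1))
print("Pyro-style check:   ", bool(pyro_style_check(x)))
print("NumPyro-style check:", bool(numpyro_style_check(x)))
```

With this input the symmetric check accepts the vector and the one-sided check rejects it, which matches what I think is happening with my Dirichlet draws.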

Is this intentional? It seems to cause a problem when:

```
theta = numpyro.sample("theta", dist.Dirichlet(jnp.ones([n_topics])))
```

since the program rejects draws for theta on the basis of them not satisfying the simplex constraint.

Is there anything I can do to make the above sampling statement work?

Many thanks,

Elchorro