Hello everyone,
I am not sure if this is the right place (would GitHub be better?), but I have been struggling with an issue that, I believe, is caused by machine precision and by how NumPyro implements the simplex constraint.
Consider this minimal example:
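```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

n_topics = 10
n_docs = 50
alpha = 20 / n_topics
key = jax.random.PRNGKey(0)

# Draw one topic-proportion vector per document: shape (n_docs, n_topics)
theta = dist.Dirichlet(jnp.ones([n_topics]) * alpha).sample(key, sample_shape=(n_docs,))

# Check every row against NumPyro's simplex constraint
jnp.all(dist.constraints.simplex(theta))
```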
Unfortunately, this returns `False`, i.e. at least one draw from the Dirichlet does not lie on the simplex according to the constraint check (regardless of whether `use_x64` is set to `True` or `False`).
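To see which rows fail and by how much, a small diagnostic along these lines may help. It is only a sketch: it assumes `jax.random.dirichlet` is an acceptable stand-in for `dist.Dirichlet(...).sample` (the two need not produce identical draws for the same key), and `numpyro_style_simplex` is a hypothetical local re-implementation of the NumPyro check quoted further down in this post, not a library function.

```python
import jax
import jax.numpy as jnp

n_topics = 10
n_docs = 50
alpha = 20 / n_topics
key = jax.random.PRNGKey(0)

# Stand-in for dist.Dirichlet(...).sample(key, sample_shape=(n_docs,));
# the draws are not guaranteed to match NumPyro's exactly.
theta = jax.random.dirichlet(key, jnp.ones(n_topics) * alpha, shape=(n_docs,))


def numpyro_style_simplex(x):
    # Hypothetical re-implementation of the NumPyro rule quoted in this post:
    # strictly positive entries and 1 - 1e-6 < sum <= 1.
    x_sum = jnp.sum(x, axis=-1)
    return jnp.all(x > 0, axis=-1) & (x_sum <= 1) & (x_sum > 1 - 1e-6)


ok = numpyro_style_simplex(theta)
deviation = jnp.sum(theta, axis=-1) - 1.0  # signed rounding error of each row sum

# List the rows that fail the check, with how far their sums are from 1.
failed = jnp.nonzero(~ok)[0]
for i in failed.tolist():
    print(f"row {i}: sum - 1 = {float(deviation[i]):.3e}")
print(f"{int(failed.size)} of {n_docs} rows fail")
```

If the failing rows all show a small positive `sum - 1`, that would point at the `x_sum <= 1` condition rather than at the positivity requirement.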
I noticed that Pyro implements the simplex constraint like this:
```python
torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
```
NumPyro, on the other hand, uses:
```python
x_sum = jnp.sum(x, axis=-1)
jnp.all(x > 0, axis=-1) & (x_sum <= 1) & (x_sum > 1 - 1e-6)
```
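So the difference seems to be that when 1 < x_sum < 1 + 1e-6, Pyro accepts the value but NumPyro rejects it. (A smaller difference is that Pyro allows entries equal to 0 via `>= 0`, whereas NumPyro requires strictly positive entries.)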
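The two rules can be compared side by side on hand-made inputs. This is a minimal sketch: `pyro_style_simplex` and `numpyro_style_simplex` are hypothetical local re-implementations of the two snippets quoted above (the Pyro one translated from `torch` to `jax.numpy`), not calls into either library.

```python
import jax.numpy as jnp


def pyro_style_simplex(x):
    # Pyro-style rule: non-negative entries and |sum - 1| < 1e-6.
    return jnp.all(x >= 0, axis=-1) & (jnp.abs(jnp.sum(x, axis=-1) - 1) < 1e-6)


def numpyro_style_simplex(x):
    # NumPyro-style rule as quoted: positive entries and 1 - 1e-6 < sum <= 1.
    x_sum = jnp.sum(x, axis=-1)
    return jnp.all(x > 0, axis=-1) & (x_sum <= 1) & (x_sum > 1 - 1e-6)


# Row sum slightly above 1 (about 1 + 4.8e-7 in float32).
slightly_over = jnp.array([0.5, 0.5 + 5e-7])
# Boundary point with an exact zero entry.
with_zero = jnp.array([0.0, 1.0])

for name, x in [("slightly_over", slightly_over), ("with_zero", with_zero)]:
    print(name, bool(pyro_style_simplex(x)), bool(numpyro_style_simplex(x)))
# slightly_over True False
# with_zero True False
```

The upper-bound difference is presumably what trips up the Dirichlet draws in the example, since rounding error can push a row sum just above 1.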
Is this intentional? It seems to cause a problem with a sampling statement like this one:
```python
# inside a NumPyro model
theta = numpyro.sample("theta", dist.Dirichlet(jnp.ones([n_topics])))
```
since the program rejects draws of `theta` for not satisfying the simplex constraint.
Is there anything I can do to make the above sampling statement work?
Many thanks,
Elchorro