NumPyro's implementation of constraints.simplex seems problematic

Hello everyone,

I am not sure whether this is the right place (would GitHub be better?), but I have been struggling with an issue that I believe is caused by floating-point precision and by how NumPyro implements the simplex constraint.

Consider this minimal example:

    import jax
    import jax.numpy as jnp
    import numpyro.distributions as dist

    n_topics = 10
    n_docs = 50
    alpha = 20 / n_topics

    key = jax.random.PRNGKey(0)

    # One topic-proportion vector per document: shape (n_docs, n_topics)
    theta = dist.Dirichlet(jnp.ones([n_topics]) * alpha).sample(key, sample_shape=(n_docs,))
    # Check every row against NumPyro's simplex constraint
    jnp.all(dist.constraints.simplex(theta))

Unfortunately, this returns False, i.e. not all draws from the Dirichlet pass the simplex check (regardless of whether use_x64 is set to true or false).
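A quick way to see that this is a rounding issue is to look at the row sums directly. This sketch assumes `jax.random.dirichlet` as a stand-in for the NumPyro sampler (it is JAX's own Dirichlet sampler, not necessarily what `dist.Dirichlet` uses internally), with the same sizes as above and default float32 precision:

```python
import jax
import jax.numpy as jnp

n_topics = 10
n_docs = 50
alpha = 20 / n_topics

key = jax.random.PRNGKey(0)
# Assumption: jax.random.dirichlet stands in for NumPyro's Dirichlet sampler.
# With x64 disabled, the draws are float32.
theta = jax.random.dirichlet(key, jnp.ones(n_topics) * alpha, shape=(n_docs,))

row_sums = theta.sum(axis=-1)        # one sum per document, shape (n_docs,)
deviation = jnp.abs(row_sums - 1.0)  # how far each sum is from exactly 1

print("max |sum - 1|:", float(deviation.max()))
print("rows with sum > 1:", int((row_sums > 1.0).sum()))
```

If the maximum deviation is on the order of float32 rounding (around 1e-7), the draws are valid up to precision, yet any row whose sum lands even one ulp above 1 fails a check that insists on `sum <= 1`.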

I noticed that Pyro implements the simplex constraint check like this:

    torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)

NumPyro, on the other hand, uses:

    x_sum = jnp.sum(x, axis=-1)
    jnp.all(x > 0, axis=-1) & (x_sum <= 1) & (x_sum > 1 - 1e-6)

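So there seem to be two differences: when 1 < x_sum < 1 + 1e-6, Pyro accepts the value but NumPyro does not, and NumPyro requires strictly positive entries (x > 0) where Pyro also accepts zeros (value >= 0).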
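The two checks can be compared side by side on hand-built float32 vectors. The function names below are hypothetical; their bodies simply restate the two snippets quoted in the post using `jnp`:

```python
import jax.numpy as jnp

def pyro_style_simplex(x):
    # Pyro-style: non-negative entries and |sum - 1| < 1e-6 (symmetric tolerance)
    return jnp.all(x >= 0, axis=-1) & (jnp.abs(x.sum(-1) - 1) < 1e-6)

def numpyro_style_simplex(x):
    # NumPyro-style, as quoted in the post: strictly positive entries
    # and 1 - 1e-6 < sum <= 1 (one-sided tolerance)
    x_sum = jnp.sum(x, axis=-1)
    return jnp.all(x > 0, axis=-1) & (x_sum <= 1) & (x_sum > 1 - 1e-6)

# Sums to exactly 1 + 2**-23 in float32 (about 1.2e-7 above 1)
slightly_over = jnp.array([0.5, 0.5 + 2**-23], dtype=jnp.float32)
# A point on the boundary of the simplex, with one zero entry
with_zero = jnp.array([0.0, 1.0], dtype=jnp.float32)

for name, x in [("slightly_over", slightly_over), ("with_zero", with_zero)]:
    print(name, bool(pyro_style_simplex(x)), bool(numpyro_style_simplex(x)))
```

The first vector is within 1e-6 of the simplex but fails the one-sided upper bound `x_sum <= 1`; the second fails only because of the strict `x > 0`.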

Is this intentional? It seems to cause a problem with a sampling statement such as:

    theta = numpyro.sample("theta", dist.Dirichlet(jnp.ones([n_topics])))

since draws of theta get rejected for not satisfying the simplex constraint.

Is there anything I can do to make the above sampling statement work?

Many thanks,
Elchorro

Hi @Elchorro, you are right that there are differences between the two constraints. Could you submit a PR to relax it? Thanks!
