While doing some sanity checks, I noticed a possible issue with the `log_prob` of the `Dirichlet` distribution when evaluating the log probability at the vertices of the simplex. Running the following code:
```python
import numpy as np
from scipy.stats import dirichlet
import numpyro

x = np.array([1.0, 0.0, 0.0])  # a vertex of the simplex
alpha = np.ones((3,))          # flat Dirichlet concentration

print("SciPy: ", dirichlet.logpdf(x, alpha))
print("NumPyro: ", numpyro.distributions.Dirichlet(alpha).log_prob(x))
```
gives the following output:
```
SciPy: 0.6931471805599453
NumPyro: nan
```
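For reference, here is the Dirichlet log density written out. Plugging in $\alpha = (1, 1, 1)$ and $x = (1, 0, 0)$ shows where the two results diverge. My assumption is that NumPyro evaluates the last sum as a plain product $(\alpha_i - 1)\log x_i$:

$$
\log p(x \mid \alpha) = \log\Gamma\Big(\sum_i \alpha_i\Big) - \sum_i \log\Gamma(\alpha_i) + \sum_i (\alpha_i - 1)\log x_i
$$

$$
\log p = \log\Gamma(3) - 3\log\Gamma(1) + 0\cdot\log 1 + 0\cdot\log 0 + 0\cdot\log 0
$$

With the usual convention $0 \cdot \log 0 = 0$, this gives $\log\Gamma(3) = \log 2 \approx 0.6931$, which matches SciPy. In IEEE floating point, however, $\log 0 = -\infty$ and $0 \cdot (-\infty) = \text{NaN}$, which would explain the `nan` from NumPyro.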
I think the two results can be made consistent by changing `jnp.sum` to `jnp.nansum` on this line. However, I'm not sure whether the current behaviour is intended, or whether the change would cause problems for inputs that genuinely should give a `nan`.
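To make the trade-off concrete, below is a small standalone sketch. It does not reproduce NumPyro's actual implementation. The three helpers (`dirichlet_log_prob_naive`, `dirichlet_log_prob_nansum`, `dirichlet_log_prob_xlogy`) are hypothetical and only illustrate the options. The `nansum` variant fixes the vertex case but also silently swallows a `nan` coming from a bad input. The alternative uses `jax.scipy.special.xlogy`, which defines `xlogy(0, y) = 0`, so it only zeroes out the $(\alpha_i - 1)\log x_i$ terms where $\alpha_i = 1$. As far as I can tell, SciPy's `dirichlet` also uses `xlogy` for this term, but that is worth double-checking.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def _log_normalizer(concentration):
    # log Gamma(sum alpha) - sum log Gamma(alpha_i)
    return gammaln(jnp.sum(concentration, axis=-1)) - jnp.sum(gammaln(concentration), axis=-1)


def dirichlet_log_prob_naive(concentration, value):
    # Hypothetical helper: plain product, so 0 * log(0) = 0 * -inf = nan.
    return jnp.sum((concentration - 1.0) * jnp.log(value), axis=-1) + _log_normalizer(concentration)


def dirichlet_log_prob_nansum(concentration, value):
    # Hypothetical helper: the proposed jnp.sum -> jnp.nansum change.
    # Drops *every* nan term, including ones caused by invalid inputs.
    return jnp.nansum((concentration - 1.0) * jnp.log(value), axis=-1) + _log_normalizer(concentration)


def dirichlet_log_prob_xlogy(concentration, value):
    # Hypothetical helper: xlogy(a, x) is 0 where a == 0, else a * log(x),
    # so only the alpha_i == 1 terms are zeroed; other nans still propagate.
    return jnp.sum(xlogy(concentration - 1.0, value), axis=-1) + _log_normalizer(concentration)


# Case from the post: flat Dirichlet evaluated at a vertex of the simplex.
alpha = jnp.ones(3)
vertex = jnp.array([1.0, 0.0, 0.0])

# An invalid input (contains nan) with alpha != 1, where nan arguably should propagate.
bad_alpha = jnp.array([2.0, 2.0, 2.0])
bad_value = jnp.array([jnp.nan, 0.5, 0.5])

if __name__ == "__main__":
    print("naive  @ vertex:", dirichlet_log_prob_naive(alpha, vertex))
    print("nansum @ vertex:", dirichlet_log_prob_nansum(alpha, vertex))
    print("xlogy  @ vertex:", dirichlet_log_prob_xlogy(alpha, vertex))
    print("nansum @ bad input:", dirichlet_log_prob_nansum(bad_alpha, bad_value))
    print("xlogy  @ bad input:", dirichlet_log_prob_xlogy(bad_alpha, bad_value))
```

In this sketch both `nansum` and `xlogy` give roughly `log 2` at the vertex. Only the `xlogy` version still returns `nan` for the invalid input, which seems closer to the behaviour the post is worried about preserving. I have not checked how either variant affects gradients at the boundary.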
Thank you for any help!