Possible Dirichlet log_prob issue

Hi all,

I was doing some sanity checks and noticed a possible issue with the Dirichlet distribution's log_prob when it is evaluated at the vertices of the simplex. Running the code:

import numpy as np
from scipy.stats import dirichlet
import numpyro.distributions as dist

alpha = np.ones(3)
vertex = np.array([1.0, 0.0, 0.0])

print("SciPy: ", dirichlet.logpdf(vertex, alpha))
print("NumPyro: ", dist.Dirichlet(alpha).log_prob(vertex))

gives the following output:

SciPy:  0.6931471805599453
NumPyro:  nan
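For reference, SciPy's value is what the Dirichlet log-density gives at this point, provided the term 0 · log 0 is taken to be 0 (its limit). For concentration α over K categories:

$$
\log p(x \mid \alpha) = \log\Gamma\Big(\sum_{i=1}^{K}\alpha_i\Big) - \sum_{i=1}^{K}\log\Gamma(\alpha_i) + \sum_{i=1}^{K}(\alpha_i - 1)\log x_i
$$

$$
\log p\big((1,0,0) \mid (1,1,1)\big) = \log\Gamma(3) - 3\log\Gamma(1) + \underbrace{0\cdot\log 1 + 0\cdot\log 0 + 0\cdot\log 0}_{=\,0 \text{ if } 0\log 0 := 0} = \log 2 \approx 0.6931
$$

In IEEE floating point, however, log(0) is -inf and 0 · (-inf) is nan, which would explain the NumPyro output. Only α_i = 1 exactly produces nan this way: for α_i > 1 the term is -inf, and for α_i < 1 it is +inf.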

I think the two can be made consistent by changing jnp.sum to jnp.nansum on this line, but I'm not sure whether the current result is the expected behaviour, or whether nansum would cause problems for inputs that should genuinely give a nan.
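To see what nansum fixes and what it might hide, here is a small JAX sketch that recomputes the log density three ways. It is not NumPyro's code: the helper name log_prob_variants is hypothetical, and the third variant swaps in jax.scipy.special.xlogy (which returns 0 wherever its first argument is 0), an alternative that was not proposed in this thread.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def log_prob_variants(concentration, value):
    """Hypothetical helper: Dirichlet log-density computed three ways."""
    # Log normalizer: log Gamma(sum alpha) - sum_i log Gamma(alpha_i).
    log_norm = gammaln(jnp.sum(concentration, axis=-1)) - jnp.sum(
        gammaln(concentration), axis=-1
    )
    # Per-coordinate terms (alpha_i - 1) * log(x_i); 0 * log(0) = 0 * -inf = nan.
    raw = (concentration - 1.0) * jnp.log(value)
    return {
        # Plain sum: any nan term makes the whole result nan.
        "sum": jnp.sum(raw, axis=-1) + log_norm,
        # nansum: treats every nan term as 0, whatever caused it.
        "nansum": jnp.nansum(raw, axis=-1) + log_norm,
        # xlogy(a, x) is 0 wherever a == 0, otherwise a * log(x).
        "xlogy": jnp.sum(xlogy(concentration - 1.0, value), axis=-1) + log_norm,
    }


alpha = jnp.ones(3)
vertex = jnp.array([1.0, 0.0, 0.0])      # a vertex of the simplex
invalid = jnp.array([0.75, 0.75, -0.5])  # sums to 1 but has a negative entry

for name, lp in log_prob_variants(alpha, vertex).items():
    print("vertex, alpha=1 ", name, lp)
for name, lp in log_prob_variants(2.0 * alpha, invalid).items():
    print("invalid, alpha=2", name, lp)
```

At the vertex with α = 1, both nansum and xlogy recover log 2. For the off-simplex point with α = 2, log(-0.5) is nan: nansum silently drops that term and returns a finite number, while sum and xlogy still propagate the nan. Note that xlogy also returns 0 for a negative x_i when α_i = 1, so neither variant replaces support validation.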

Thank you for any help!

Thanks! Yes, nansum would work, but I think we need to clip value before calling the log there (we already clip the values in the sample method). Do you want to submit a fix, or create a GitHub issue for this?
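A minimal sketch of the clipping approach, assuming value is clipped to [finfo.tiny, 1 - finfo.eps] before the log. The function name clipped_dirichlet_log_prob and these exact bounds are illustrative assumptions, not NumPyro's implementation.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def clipped_dirichlet_log_prob(concentration, value):
    """Hypothetical Dirichlet log-density that clips `value` before taking the log."""
    value = jnp.asarray(value)
    # Assumed bounds: smallest positive normal float and 1 - machine epsilon,
    # so that log(safe) is always finite.
    finfo = jnp.finfo(value.dtype)
    safe = jnp.clip(value, finfo.tiny, 1.0 - finfo.eps)
    log_norm = gammaln(jnp.sum(concentration, axis=-1)) - jnp.sum(
        gammaln(concentration), axis=-1
    )
    # With a finite log, (alpha_i - 1) * log(x_i) is exactly 0 when alpha_i == 1.
    return jnp.sum((concentration - 1.0) * jnp.log(safe), axis=-1) + log_norm


alpha = jnp.ones(3)
vertex = jnp.array([1.0, 0.0, 0.0])
print(clipped_dirichlet_log_prob(alpha, vertex))        # close to log 2
print(clipped_dirichlet_log_prob(0.5 * alpha, vertex))  # large but finite
```

With α = 1 every log term is now finite, so the zero coefficients give 0 and the vertex evaluates to log 2, matching SciPy. The trade-off is that for α_i < 1, where the density is unbounded at the boundary, the clipped version returns a large finite value rather than +inf.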