Hi all,

I was running some sanity checks and noticed a possible issue with `log_prob` of the `Dirichlet` distribution when it is evaluated at the vertices of the simplex. Running the code:

```python
import numpy as np
from scipy.stats import dirichlet
import numpyro.distributions as dist

x = np.array([1.0, 0.0, 0.0])  # a vertex of the 2-simplex
alpha = np.ones(3)             # concentration (1, 1, 1), i.e. uniform on the simplex

print("SciPy: ", dirichlet.logpdf(x, alpha))
print("NumPyro: ", dist.Dirichlet(alpha).log_prob(x))
```

gives the following output:

```
SciPy: 0.6931471805599453
NumPyro: nan
```
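For reference, with α = (1, 1, 1) the Dirichlet density is constant on the simplex and equal to Γ(3)/Γ(1)³ = 2, so log 2 ≈ 0.6931 is the value SciPy reports. A likely source of the `nan` is the (α_i − 1) · log x_i term: at a vertex, x_i = 0 and α_i − 1 = 0, so the term evaluates to 0 · (−∞), which is `nan` in IEEE floating point, whereas the density formula uses the convention 0 · log 0 = 0. The sketch below reproduces this in plain NumPy. It assumes the kernel is computed naively as sum((α − 1) · log x); the function names are hypothetical and this is not NumPyro's actual code. `scipy.special.xlogy` applies the 0 · log 0 = 0 convention, and JAX provides `jax.scipy.special.xlogy` as a counterpart.

```python
import numpy as np
from scipy.special import gammaln, xlogy


def log_beta(alpha):
    # Log of the multivariate Beta function, the Dirichlet normalizer.
    return np.sum(gammaln(alpha), axis=-1) - gammaln(np.sum(alpha, axis=-1))


def dirichlet_logpdf_naive(x, alpha):
    # Hypothetical naive kernel: sum((alpha - 1) * log(x)).
    # For x_i = 0 and alpha_i = 1 this computes 0 * (-inf) = nan.
    with np.errstate(divide="ignore", invalid="ignore"):
        kernel = np.sum((alpha - 1.0) * np.log(x), axis=-1)
    return kernel - log_beta(alpha)


def dirichlet_logpdf_xlogy(x, alpha):
    # xlogy(a, b) returns 0 when a == 0 (and b is not nan), i.e. 0 * log(0) := 0.
    kernel = np.sum(xlogy(alpha - 1.0, x), axis=-1)
    return kernel - log_beta(alpha)


alpha = np.ones(3)
vertex = np.array([1.0, 0.0, 0.0])
print(dirichlet_logpdf_naive(vertex, alpha))  # nan
print(dirichlet_logpdf_xlogy(vertex, alpha))  # log(2), about 0.6931
```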

I think this could be made consistent with SciPy by replacing `jnp.sum` with `jnp.nansum` on this line, but I'm not sure whether the current behaviour is intended, or whether the change would cause problems for inputs that should legitimately produce `nan`.
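On the `nansum` concern: `nansum` treats every `nan` as zero, so it would also hide `nan`s that signal a genuinely invalid input, not just the 0 · log 0 case. A sketch in plain NumPy (hypothetical function names, not NumPyro code) comparing a `nansum`-based kernel with an `xlogy`-based one: both agree at the vertex, but for a point with a negative coordinate (α = (2, 2, 2) is chosen purely for illustration) the `nansum` version silently returns a finite value while SciPy's `xlogy` keeps the `nan`.

```python
import numpy as np
from scipy.special import xlogy


def kernel_nansum(x, alpha):
    # Proposed change: drop every nan term before summing.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.nansum((alpha - 1.0) * np.log(x), axis=-1)


def kernel_xlogy(x, alpha):
    # Alternative: only the 0 * log(0) case is mapped to 0; other nans propagate.
    with np.errstate(invalid="ignore"):
        return np.sum(xlogy(alpha - 1.0, x), axis=-1)


vertex = np.array([1.0, 0.0, 0.0])
print(kernel_nansum(vertex, np.ones(3)), kernel_xlogy(vertex, np.ones(3)))  # 0.0 0.0

outside = np.array([-0.2, 0.6, 0.6])  # negative coordinate: not on the simplex
alpha2 = np.full(3, 2.0)
print(kernel_nansum(outside, alpha2))  # finite: the nan from log(-0.2) is dropped
print(kernel_xlogy(outside, alpha2))   # nan
```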

Thank you for any help!