Thank you so much for solving my problem!
For me, in this concrete example:
- both fixes are necessary: shifting the Pareto by an epsilon to the left, and choosing a different implementation of SkewNormal log prob
- out of the three options for `log(1 + erf(x))`:
  - the `log1p_erf` by @fehiepsi above
  - `jnp.log(jax.scipy.special.erfc(-x))`
  - `dist.Normal().log_cdf(x * jnp.sqrt(2))`

  the second runs about 30% faster than the other two.
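
For anyone who wants to compare these forms, here is a minimal sketch, not the code from this thread. It assumes plain JAX only: `jax.scipy.stats.norm.logcdf` stands in for `dist.Normal().log_cdf`, the function names are made up for illustration, and @fehiepsi's `log1p_erf` is not reproduced. Note that the normal-CDF form equals `log(1 + erf(x))` only up to an additive `log(2)`, which the sketch adds back.

```python
import jax.numpy as jnp
from jax.scipy.special import erf, erfc
from jax.scipy.stats import norm


def log1p_erf_naive(x):
    # Direct form; 1 + erf(x) cancels badly once erf(x) is close to -1.
    return jnp.log1p(erf(x))


def log1p_erf_erfc(x):
    # erfc(-x) = 1 - erf(-x) = 1 + erf(x), evaluated without the cancellation.
    return jnp.log(erfc(-x))


def log1p_erf_logcdf(x):
    # Phi(z) = (1 + erf(z / sqrt(2))) / 2, so
    # log(1 + erf(x)) = log Phi(x * sqrt(2)) + log(2).
    # jax.scipy.stats.norm.logcdf is a stand-in for dist.Normal().log_cdf.
    return norm.logcdf(x * jnp.sqrt(2.0)) + jnp.log(2.0)


x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
for f in (log1p_erf_naive, log1p_erf_erfc, log1p_erf_logcdf):
    print(f.__name__, f(x))
```

On this moderate range all three should agree closely. The differences show up for strongly negative `x`, where the direct form loses precision first. Timings will depend on hardware and on whether the functions are jitted, so the 30% figure above is specific to my model.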
- the