You might need to relax the Pareto support a bit so that `rvs.min()` lies inside it. The support is the open interval `(latent_pareto_m + latent_pareto_shift, ∞)`, so the lower bound has to sit strictly below `rvs.min()`, e.g.:
```python
latent_pareto_shift = numpyro.deterministic("latent_pareto_shift", rvs.min() - latent_pareto_m - 1e-10)
```
(Sorry, I wasn't aware of this issue earlier because I had frozen some parameters while debugging.)
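One caveat with a fixed `1e-10` margin: JAX defaults to float32, whose spacing near 1 is on the order of 1e-7, so `rvs.min() - latent_pareto_m - 1e-10` can round to `rvs.min() - latent_pareto_m` and the lower bound then lands exactly on `rvs.min()`, which the open interval excludes. A minimal sketch, assuming float32 data and a hypothetical helper `pareto_shift` (not part of NumPyro), scales the margin with the magnitudes involved:

```python
import jax.numpy as jnp


def pareto_shift(rvs, pareto_m, rel_eps=1e-5):
    """Hypothetical helper: pick a shift so that rvs.min() lies strictly
    inside the open support (pareto_m + shift, inf)."""
    rvs_min = jnp.min(rvs)
    # Scale the margin with the largest magnitude involved so it stays well
    # above float32 rounding error; a fixed 1e-10 is below float32
    # resolution for values of order 1.
    scale = jnp.maximum(jnp.maximum(jnp.abs(rvs_min), jnp.abs(pareto_m)), 1.0)
    margin = rel_eps * scale
    return rvs_min - pareto_m - margin


# Usage: float32 data where the fixed 1e-10 margin rounds away.
rvs = jnp.array([1.0, 2.5, 3.0], dtype=jnp.float32)
m = 0.5
naive_shift = jnp.min(rvs) - m - 1e-10  # rounds to exactly 0.5 in float32
shift = pareto_shift(rvs, m)
print(bool(jnp.min(rvs) > m + naive_shift))  # False: bound equals rvs.min()
print(bool(jnp.min(rvs) > m + shift))  # True: bound is strictly below
```

Inside the model this could be wrapped the same way as before, e.g. `numpyro.deterministic("latent_pareto_shift", pareto_shift(rvs, latent_pareto_m))`. With `jax_enable_x64` turned on and float64 data of moderate magnitude, the original `1e-10` margin is usually large enough.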
The SkewNormal `log_prob` can be implemented as:
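```python
@dist.util.validate_sample
def log_prob(self, value):
    # standardized value z = (x - mu) / sigma
    normalized = (value - self.mu) / self.sigma
    # log f(x) = log Phi(alpha * z) - z**2 / 2 - log(sigma * sqrt(pi / 2))
    return dist.Normal().log_cdf(normalized * self.alpha) - 0.5 * normalized ** 2 - jnp.log(self.sigma * jnp.sqrt(jnp.pi / 2))
```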
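To check the formula outside the distribution class, here is a standalone sketch. It uses `jax.scipy.special.log_ndtr` for log Φ in place of `dist.Normal().log_cdf`, and the function name `skew_normal_log_prob` is hypothetical. It evaluates the log of the skew-normal density (2/σ) φ(z) Φ(αz) with z = (x − μ)/σ; the comments show where the constant `log(sigma * sqrt(pi / 2))` comes from.

```python
import jax.numpy as jnp
from jax.scipy.special import log_ndtr


def skew_normal_log_prob(value, mu, sigma, alpha):
    """Hypothetical standalone version of SkewNormal.log_prob."""
    # Density: f(x) = (2 / sigma) * phi(z) * Phi(alpha * z), z = (x - mu) / sigma.
    z = (value - mu) / sigma
    # log f = log Phi(alpha z) - z^2 / 2 + log 2 - log sigma - log(2 pi) / 2,
    # and log 2 - log(2 pi) / 2 = -log(sqrt(pi / 2)), giving the constant below.
    return log_ndtr(alpha * z) - 0.5 * z**2 - jnp.log(sigma * jnp.sqrt(jnp.pi / 2))


x = jnp.array([-2.0, 0.0, 1.5, 3.0])
print(skew_normal_log_prob(x, 0.5, 1.3, 2.0))
```

A quick sanity check: with `alpha = 0` the result reduces to the Normal(mu, sigma) log density, since log Φ(0) = log(1/2). For a reference value, `scipy.stats.skewnorm.logpdf(x, alpha, loc=mu, scale=sigma)` uses the same parameterization.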