Hello devs.
I’m looking to use GEV distribution for observation noise. I have implemented the log_prob method, but not sure how to implement the sample method. I have to implement it because I’m using this distribution for observation noise… (so, I need the sample method to be able to simulate fake observations from the model)
This is what I have so far…
from jax import lax
import jax.numpy as jnp
import jax.random as random
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes
class GeneralizedExtremeValue(Distribution):
arg_constraints = {
"loc": constraints.real,
"scale": constraints.positive,
"shape": constraints.real
}
support = None
def __init__(self, loc, scale, shape, *, validate_args=None):
batch_shape = lax.broadcast_shapes(
jnp.shape(loc), jnp.shape(scale), jnp.shape(shape)
)
self.loc, self.scale, self.shape = promote_shapes(
loc, scale, shape, shape=batch_shape
)
support = jnp.where(
self.shape > 0,
constraints.greater_than(self.loc - (self.scale / self.shape)),
constraints.real
)
support = jnp.where(
self.shape < 0,
constraints.less_than(self.loc - (self.scale / self.shape)),
support
)
super(GeneralizedExtremeValue, self).__init__(batch_shape, validate_args=validate_args)
def sample(self, key, sample_shape=()):
raise NotImplementedError
def log_prob(self, value):
z = (value - self.loc) / self.scale
t = jnp.where(
self.shape != 0,
jnp.power(1 + (self.shape * z), -1 / self.shape),
jnp.exp(z)
)
return (
- jnp.log(self.scale)
+ (self.shape + 1) * jnp.log(t)
- t
)