How to implement a Generalized Extreme Value (GEV) distribution in NumPyro?

Hello devs.

I’m looking to use a GEV distribution for observation noise. I have implemented the `log_prob` method, but I’m not sure how to implement the `sample` method. I need `sample` because the distribution models observation noise, so it must be able to simulate fake observations from the model.

This is what I have so far…

from jax import lax
import jax.numpy as jnp
import jax.random as random

from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes


class GeneralizedExtremeValue(Distribution):
    r"""Generalized Extreme Value (GEV) distribution.

    Let :math:`z = (x - \mu) / \sigma` and

    .. math::
        t(x) = \begin{cases}
            (1 + \xi z)^{-1/\xi} & \xi \neq 0 \\
            e^{-z} & \xi = 0
        \end{cases}

    Then the CDF is :math:`F(x) = e^{-t(x)}` and the density is
    :math:`f(x) = \sigma^{-1} t(x)^{\xi + 1} e^{-t(x)}` on the set where
    :math:`1 + \xi z > 0`.

    :param loc: location :math:`\mu`.
    :param scale: scale :math:`\sigma > 0`.
    :param shape: shape :math:`\xi`. ``shape > 0`` bounds the support below
        at ``loc - scale / shape``; ``shape < 0`` bounds it above at the same
        point; ``shape == 0`` is the (unbounded) Gumbel distribution.
    :param validate_args: forwarded to :class:`Distribution`.

    NOTE(review): storing the parameter as ``self.shape`` shadows the
    inherited ``Distribution.shape()`` method on instances. The name is kept
    so existing callers keep working, but renaming the attribute (TFP calls
    it ``concentration``) would avoid the clash.
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "shape": constraints.real,
    }
    # Samples are a differentiable transform of parameter-free uniform noise
    # (see ``sample``), so gradients can flow through them.
    reparametrized_params = ["loc", "scale", "shape"]

    def __init__(self, loc, scale, shape, *, validate_args=None):
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc), jnp.shape(scale), jnp.shape(shape)
        )
        self.loc, self.scale, self.shape = promote_shapes(
            loc, scale, shape, shape=batch_shape
        )
        # The support depends on the parameters, so it is computed lazily by
        # the ``support`` property instead of here; ``jnp.where`` cannot
        # select between constraint objects.
        super().__init__(batch_shape, validate_args=validate_args)

    def _shape_mask(self):
        """Return ``(is_gumbel, safe_shape)`` for branch-free evaluation.

        ``safe_shape`` equals ``self.shape`` except where it is exactly zero,
        where it is replaced by 1. Divisions in the unselected ``jnp.where``
        branch then never produce inf/NaN, which would otherwise leak into
        gradients.
        """
        is_gumbel = self.shape == 0
        return is_gumbel, jnp.where(is_gumbel, 1.0, self.shape)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        """Parameter-dependent support, expressed as a (possibly infinite) interval.

        NOTE(review): an interval with infinite endpoints is fine for
        ``support`` checks on observed values. It is not suitable for
        ``biject_to`` if this distribution is ever used for a latent site.
        """
        _, xi_safe = self._shape_mask()
        endpoint = self.loc - self.scale / xi_safe
        lower = jnp.where(self.shape > 0, endpoint, -jnp.inf)
        upper = jnp.where(self.shape < 0, endpoint, jnp.inf)
        return constraints.interval(lower, upper)

    def icdf(self, q):
        """Quantile function: the ``x`` with ``F(x) = q`` for ``q`` in (0, 1)."""
        is_gumbel, xi_safe = self._shape_mask()
        # Invert F(x) = exp(-t(x)):  t = -log(q).
        log_t = jnp.log(-jnp.log(q))
        # xi != 0: z = (t**(-xi) - 1) / xi, written with expm1 for accuracy
        # when xi is small; it tends to the Gumbel value -log(t) as xi -> 0.
        z = jnp.where(is_gumbel, -log_t, jnp.expm1(-xi_safe * log_t) / xi_safe)
        return self.loc + self.scale * z

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` by inverse-CDF sampling."""
        out_shape = sample_shape + self.batch_shape + self.event_shape
        finfo = jnp.finfo(jnp.result_type(float))
        # random.uniform draws from [minval, maxval). Starting at the smallest
        # positive normal number keeps -log(u) finite, and excluding maxval=1
        # keeps it strictly positive, so icdf never takes log(0).
        u = random.uniform(key, out_shape, minval=finfo.tiny, maxval=1.0)
        return self.icdf(u)

    def log_prob(self, value):
        """Log density; ``-inf`` outside the support (where ``1 + shape * z <= 0``)."""
        is_gumbel, xi_safe = self._shape_mask()
        z = (value - self.loc) / self.scale
        xz = xi_safe * z
        # For xi == 0 every real value is in the support.
        in_support = is_gumbel | (xz > -1.0)
        # Keep log1p's argument valid in masked-out positions to avoid NaNs.
        xz_safe = jnp.where(in_support, xz, 0.0)
        # log t(x): -log1p(xi z) / xi for xi != 0, and -z for the Gumbel case.
        log_t = jnp.where(is_gumbel, -z, -jnp.log1p(xz_safe) / xi_safe)
        lp = -jnp.log(self.scale) + (self.shape + 1) * log_t - jnp.exp(log_t)
        return jnp.where(in_support, lp, -jnp.inf)

It looks like you may be able to use the TensorFlow Probability (TFP) implementation through NumPyro’s TFP wrapper:

`tfp.distributions.GeneralizedExtremeValue` (TensorFlow Probability API documentation)

https://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution

1 Like

@martinjankowiak thank you. This worked for me!