Custom generalized extreme value distribution

Hello team,
I am writing a custom GEV distribution for my data, and I am not sure how to write the .rsample() method. The .log_prob() part works, but for SVI I need .rsample(). Any help would be greatly appreciated.

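import torch
import pyro.distributions
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

class GEV(pyro.distributions.TorchDistribution):
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive,
                       "shape": constraints.real}
    support = constraints.real  # simplification: the true support depends on shape

    def __init__(self, loc, scale, shape, validate_args=None):
        # Store (broadcast) parameters before calling the base constructor
        self.loc, self.scale, self.shape = broadcast_all(loc, scale, shape)
        super().__init__(batch_shape=self.loc.shape, event_shape=torch.Size([]),
                         validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(GEV, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        new.shape = self.shape.expand(batch_shape)
        super(GEV, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        raise NotImplementedError  # this is the part I am unsure about

    def log_prob(self, value):
        beta = (value - self.loc) / self.scale  # standardized value
        log_scale = torch.log(self.scale)
        logp = -(log_scale + ((self.shape + 1) / self.shape) * torch.log1p(self.shape * beta)
                 + (1 + self.shape * beta) ** (-1 / self.shape))
        return logp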

Hi @Sree, the .rsample() method draws a random sample in a differentiable way: it first draws parameter-free noise from a fixed distribution (say Exponential(1)) and then differentiably transforms that noise into the desired distribution. Based on the GEV formulas on Wikipedia, something like this might work for you:

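def rsample(self, sample_shape=torch.Size()):
    # Parameter-free noise E ~ Exponential(1), matching the dtype/device of loc
    noise = torch.empty(torch.Size(sample_shape) + self.batch_shape,
                        dtype=self.loc.dtype, device=self.loc.device).exponential_()
    # Differentiable transform: loc - scale * log(E) follows Gumbel(loc, scale)
    return self.loc - self.scale * noise.log()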
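For reference, here is a sketch of the inverse-CDF reasoning behind that transform, using the standard GEV CDF with location $\mu$, scale $\sigma$, shape $\xi$, and $E = -\log U \sim \mathrm{Exponential}(1)$ for $U \sim \mathrm{Uniform}(0, 1)$ (this derivation is an illustration, not part of the original reply):

$$
F(x) = \exp\!\left[-(1 + \xi z)^{-1/\xi}\right], \qquad z = \frac{x - \mu}{\sigma}
$$

$$
F(x) = U \;\Longleftrightarrow\; (1 + \xi z)^{-1/\xi} = -\log U = E
\;\Longleftrightarrow\; z = \frac{E^{-\xi} - 1}{\xi}
\;\xrightarrow{\;\xi \to 0\;}\; -\log E
$$

With $x = \mu + \sigma z$, the $\xi \to 0$ limit is exactly the `loc - scale * noise.log()` transform, i.e. the Gumbel member of the GEV family; the general case needs the $\xi$-dependent formula.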

It would be great to have a GEV implementation in Pyro in case you want to submit a pull request!

Thank you @fritzo, I implemented this. But when I run SVI with an AutoDiagonalNormal guide, I get the error “ValueError: The parameter loc has invalid values”. Do we need to write our own guide when using a custom distribution?

Hi @Sree, your distribution should work fine with AutoDiagonalNormal or other autoguides. My guess is that either (i) your learning rate is too high, or (ii) you are getting NaNs somewhere and need to add some clamping in the distribution (most of our distributions require some sort of numerical clamping to avoid NaNs). You can check for NaNs in a debugger at the point where the error occurs:
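As a rough illustration (not from the thread), the snippet below shows one way to look for non-finite values and why they surface as this error: `constraints.real` rejects NaN, so a NaN reaching a `loc` parameter (for example in the guide's Normal after a NaN gradient step) is likely what fails argument validation. The helper `nonfinite_names` is hypothetical; `torch.autograd.set_detect_anomaly(True)` is another way to locate where a NaN first appears during the backward pass.

```python
import torch
from torch.distributions import Normal

def nonfinite_names(**tensors):
    # Hypothetical helper: return the names of tensors containing NaN or +/-inf.
    return [name for name, t in tensors.items() if not torch.isfinite(t).all()]

# In a GEV log_prob, log1p(shape * beta) is NaN whenever 1 + shape * beta < 0,
# i.e. when a value lies outside the support for the current parameters.
shape_ = torch.tensor(0.5)
beta = torch.tensor([0.0, -3.0])  # 1 + 0.5 * (-3.0) = -0.5 < 0
t = torch.log1p(shape_ * beta)
print(nonfinite_names(beta=beta, t=t))  # ['t']

# A NaN loc fails the constraints.real check; the exact error wording
# depends on the PyTorch version.
raised = False
try:
    Normal(torch.tensor(float("nan")), torch.tensor(1.0), validate_args=True)
except ValueError as err:
    raised = True
    print(err)
```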


Here’s one possible way to clamp the noise to avoid NaNs:

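def rsample(self, sample_shape=torch.Size()):
    noise = torch.empty(torch.Size(sample_shape) + self.batch_shape,
                        dtype=self.loc.dtype, device=self.loc.device).exponential_()
    # Clamp away exact zeros so that log() cannot return -inf
    noise = noise.clamp(min=torch.finfo(noise.dtype).tiny)
    return self.loc - self.scale * noise.log()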
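To see what the clamp protects against, here is a small sketch with a hypothetical helper `gumbel_from_noise` (not from the thread) that applies the same transform to explicitly supplied noise, so the case where the Exponential(1) noise is exactly 0 can be reproduced deterministically:

```python
import torch

def gumbel_from_noise(loc, scale, noise, clamp=True):
    # Same transform as the rsample() above, but the Exponential(1) noise is
    # passed in so the noise == 0 edge case can be reproduced on demand.
    if clamp:
        noise = noise.clamp(min=torch.finfo(noise.dtype).tiny)
    return loc - scale * noise.log()

loc = torch.tensor(0.0)
scale = torch.tensor(1.0)
noise = torch.tensor([0.0, 1.0])

unclamped = gumbel_from_noise(loc, scale, noise, clamp=False)
clamped = gumbel_from_noise(loc, scale, noise)
print(unclamped)  # first entry is inf, since -log(0) = +inf
print(clamped)    # first entry is about 87.34 (= -log of the float32 tiny value)
```

The clamped value is large but finite, so downstream arithmetic and gradients stay finite instead of turning into inf or NaN.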

BTW, the .rsample() I suggested above doesn’t use your shape parameter (it only covers the shape = 0 case, i.e. the Gumbel distribution), so you’ll need to adjust the math. Also, I believe .shape is a reserved method of TorchDistribution, so you may need to rename that parameter, say to shape_.

Thank you, Fritzo, for the quick response. I am enjoying working with Pyro; it is an amazing package.
