Proper way of defining the support for a TransformedDistribution

Hi all,

I’m currently building an SVI model in which I want to use a prior distribution that is:

  • Constrained between [0,1]
  • Skewed towards 0

What I have come up with so far is the following:

import torch
import pyro.distributions as dist


def generate_skewed_unit_uniform_dist_via_Transform(skewedness: float = 1.):
    """
    Generates a distribution on [0, 1] that is skewed towards 0.
    Created by sampling from a uniform on [1, 1+skewedness], transforming the samples by 1/x, and then shifting and rescaling them to [0, 1].
    
    skewedness must be > 0
        Intuitively, skewedness lives on a logarithmic scale.
        At 0.1 the density looks almost uniform on a linear plot; at <= 0.01 it is practically indistinguishable from a standard Uniform distribution.
        At 1. the density at the left end is double, and at the right end half, that of a uniform.
        At >= 100 the distribution becomes extremely skewed.
    """
    assert skewedness > 0., "Skewedness must be above 0."
    uniform_params = (1., 1. + skewedness)
    exponent = torch.tensor(-1.)
    
    # 1/x maps [1, 1+skewedness] onto [1/(1+skewedness), 1]
    skewed_upper_lim = torch.tensor(uniform_params[0])**exponent
    skewed_lower_lim = torch.tensor(uniform_params[1])**exponent
    shift = -skewed_lower_lim
    scale = 1/(skewed_upper_lim - skewed_lower_lim)
    
    uniform_dist = dist.Uniform(*uniform_params)
    # x -> 1/x
    skewed_dist = dist.TransformedDistribution(uniform_dist, dist.transforms.PowerTransform(exponent))
    # shift the lower limit to 0
    skewed_zeroed_dist = dist.TransformedDistribution(skewed_dist, dist.transforms.AffineTransform(loc=shift, scale=torch.tensor(1.)))
    # rescale the upper limit to 1
    skewed_unit_size_dist = dist.TransformedDistribution(skewed_zeroed_dist, dist.transforms.AffineTransform(loc=torch.tensor(0.), scale=scale))
    
    # TransformedDistribution takes its support from the codomain of the last transform,
    # which for AffineTransform is the whole real line, so the [0, 1] support has to be patched in.
    # Monkey patch in the support, see: https://stackoverflow.com/a/31591589
    # ?TODO: Is there a nicer/proper way to do this?
    class SkewedWithSupport(dist.TransformedDistribution):
        support = dist.constraints.interval(0., 1.)
    skewed_unit_size_dist.__class__ = SkewedWithSupport
    
    return skewed_unit_size_dist

This produces the kind of distribution I was looking for, and it appears to work fine in SVI.
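For reference, the construction in the function above has a closed-form density, which makes the docstring's description easy to check. Writing $s$ for `skewedness`, $U \sim \mathrm{Uniform}(1, 1+s)$, $Y = 1/U$ and $Z$ for the shifted and rescaled value (my own notation):

```latex
\begin{aligned}
p_Y(y) &= \frac{1}{s\,y^2}, && y \in \left[\tfrac{1}{1+s},\, 1\right],\\
Z &= \frac{Y - a}{1 - a}, && a = \tfrac{1}{1+s},\\
p_Z(z) &= (1-a)\, p_Y\!\big(a + (1-a)\,z\big) = \frac{1+s}{(1+s z)^2}, && z \in [0, 1].
\end{aligned}
```

So $p_Z(0) = 1+s$ and $p_Z(1) = 1/(1+s)$: at $s = 1$ the left end is doubled and the right end halved relative to a uniform, at $s = 0.01$ both ends are within about 1% of a uniform, and the ratio $(1+s)^2$ between the two ends grows quickly with $s$, which fits the intuition that `skewedness` behaves roughly logarithmically.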

However, this is significantly different from how InverseGamma is implemented in Pyro (and in PyTorch, for that matter): pyro/pyro/distributions/inverse_gamma.py at dev · pyro-ppl/pyro · GitHub

In particular, monkey patching the support after applying the transformations feels very wrong to me. Any advice on how to do this properly would be appreciated.

Thanks in advance!