Hi all,
I’m currently building an SVI model and want to use a prior distribution that is:
- Constrained to [0, 1]
- Skewed towards 0
What I have come up with so far is the following:
```python
import torch
import pyro.distributions as dist


def generate_skewed_unit_uniform_dist_via_Transform(skewedness: float = 1.):
    """
    Generates a distribution on [0, 1] that is skewed towards 0.

    Created by sampling from a uniform on [1, 1 + skewedness], applying 1/x
    (which gives a density proportional to 1/x^2), then shifting and rescaling
    the result onto [0, 1].

    skewedness must be > 0. Intuitively, skewedness lives on a logarithmic scale:
    values of 0.1 look almost uniform on a linear plot, and at <= 0.01 the result
    is essentially indistinguishable from a Uniform(0, 1). At 1. the density at
    0 is double, and the density at 1 half, that of a uniform. At >= 100 the
    distribution becomes extremely skewed.
    """
    assert skewedness > 0., "Skewedness must be above 0."
    uniform_params = (1., 1. + skewedness)
    exponent = torch.tensor(-1.)
    # 1/x maps [1, 1 + skewedness] onto [1 / (1 + skewedness), 1]
    skewed_upper_lim = torch.tensor(uniform_params[0]) ** exponent
    skewed_lower_lim = torch.tensor(uniform_params[1]) ** exponent
    shift = -skewed_lower_lim
    scale = 1 / (skewed_upper_lim - skewed_lower_lim)

    uniform_dist = dist.Uniform(*uniform_params)
    # x -> 1/x
    skewed_dist = dist.TransformedDistribution(
        uniform_dist, dist.transforms.PowerTransform(exponent))
    # Shift the lower limit to 0
    skewed_zeroed_dist = dist.TransformedDistribution(
        skewed_dist, dist.transforms.AffineTransform(loc=shift, scale=torch.tensor(1.)))
    # Rescale so that the upper limit becomes 1
    skewed_unit_size_dist = dist.TransformedDistribution(
        skewed_zeroed_dist, dist.transforms.AffineTransform(loc=torch.tensor(0.), scale=scale))

    # Monkey patch in the support, see: https://stackoverflow.com/a/31591589
    # ?TODO: Is there a nicer/proper way to do this?
    class SkewedWithSupport(dist.TransformedDistribution):
        support = dist.constraints.interval(0., 1.)

    skewed_unit_size_dist.__class__ = SkewedWithSupport
    return skewed_unit_size_dist
```
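For reference, the change of variables behind this construction can be worked out in closed form. Writing $s$ for `skewedness`, $u \sim \mathrm{Uniform}(1, 1+s)$, $x = 1/u$, and $y$ for the final value on $[0, 1]$ (notation introduced here, not taken from the code):

$$
\begin{aligned}
p_X(x) &= \frac{1}{s\,x^2}, \qquad x \in \left[\tfrac{1}{1+s},\, 1\right],\\
y &= \frac{1+s}{s}\left(x - \frac{1}{1+s}\right) \;\Longleftrightarrow\; x = \frac{1 + s\,y}{1+s},\\
p_Y(y) &= p_X(x)\,\frac{dx}{dy} = \frac{1}{s\,x^2}\cdot\frac{s}{1+s} = \frac{1+s}{(1 + s\,y)^2}, \qquad y \in [0, 1],\\
F_Y(y) &= \frac{(1+s)\,y}{1 + s\,y}, \qquad F_Y^{-1}(v) = \frac{v}{1 + s\,(1 - v)}.
\end{aligned}
$$

So $p_Y(0) = 1+s$ and $p_Y(1) = 1/(1+s)$: for $s = 1$ that is exactly the doubled/halved density described in the docstring, and the end-to-end density ratio $(1+s)^2$ makes the "logarithmic scale" intuition concrete ($s = 0.01$ gives a ratio of about 1.02, $s = 100$ gives 10201).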
This creates the types of distributions I was looking for, and it appears to work fine in SVI.
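A likely reason the support has to be set by hand at all: `TransformedDistribution` infers its support from the codomain of its last transform, and for `AffineTransform` that is the whole real line. The sketch below rebuilds the same three transforms with plain `torch.distributions` (assuming `s = 1.0`; Pyro's `TransformedDistribution` inherits this support logic from torch) and compares what `biject_to`, which Pyro's autoguides use to map unconstrained values onto a site's support, does for the inferred support versus `constraints.unit_interval`.

```python
import torch
from torch.distributions import TransformedDistribution, Uniform, biject_to, constraints
from torch.distributions.transforms import AffineTransform, PowerTransform

s = 1.0                    # skewedness, an assumed value for illustration
lower = 1.0 / (1.0 + s)    # image of the uniform's upper limit under 1/x
scale = 1.0 / (1.0 - lower)

d = TransformedDistribution(
    Uniform(torch.tensor(1.0), torch.tensor(1.0 + s)),
    [
        PowerTransform(torch.tensor(-1.0)),                                  # x -> 1/x
        AffineTransform(loc=torch.tensor(-lower), scale=torch.tensor(1.0)),  # shift lower limit to 0
        AffineTransform(loc=torch.tensor(0.0), scale=torch.tensor(scale)),   # rescale upper limit to 1
    ],
)

# The inferred support is the last transform's codomain: the real line, not [0, 1].
print(d.support)

# biject_to(real) is the identity, so an unconstrained value of 3.0 stays 3.0 ...
print(biject_to(d.support)(torch.tensor(3.0)))
# ... whereas the unit-interval bijection squashes it into (0, 1).
print(biject_to(constraints.unit_interval)(torch.tensor(3.0)))
```

With the real-line support, a guide could propose values outside [0, 1], where the log density is not finite; declaring the unit interval as the support avoids that.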
However, my approach is significantly different from how the InverseGamma distribution is created in Pyro (and in Torch, for that matter): pyro/distributions/inverse_gamma.py on the dev branch of pyro-ppl/pyro on GitHub.
In particular, the way I monkey patch the support after applying the transformations looks very wrong to me. Any advice on how to do this properly would be appreciated.
Thanks in advance!