TruncatedNormal distribution in Pyro?

Hi, I see that NumPyro has a univariate TruncatedNormal distribution. Any suggestions for the most efficient way to implement it (or other single- or double-truncated distributions) in Pyro?


Pyro has a folded normal, Folded(Normal(...), ...), and a HalfNormal, but the PyTorch TruncatedNormal PR has languished for years :person_shrugging:


Should I just brute-force it by sampling from Uniform(a, b), feeding that into a normal, and then dividing by the renormalization constant from the CDFs at a and b? But then I'd need to define the log probability too, wouldn't I? So I guess I could subclass Distribution in Pyro and try to make it all work…

NumPyro has both one-sided truncated distributions (LeftTruncatedDistribution) and two-sided truncated distributions. Do you want to sample, or do inference? I was able to implement right-, left-, and two-sided truncated Gaussians and used them for both inference and sampling.

Yes, I noticed that; unfortunately I am working in Pyro. I was really just looking for a guidepost to copy. What I'm most interested in is a doubly truncated exponential, implemented as a TruncatedExponential class for inference as part of a bigger model. I've made good progress and written a function that works for sampling; now I want to make a class that inherits from TorchDistribution. When I have the class code ready I will post it for feedback.
EDIT: I followed this post; linking it for future people who want to do something similar.


@seanreed1111 A truncated exponential is actually simple to build in torch.distributions; you don't even need a custom class:

import pyro.distributions as dist
from torch.distributions.utils import broadcast_all
import matplotlib.pyplot as plt

def TruncatedExponential(lb, ub, rate=1.0):
    # Uniform on [exp(-rate*ub), exp(-rate*lb)], then x = -log(u) / rate lies in [lb, ub]
    lb, ub, rate = broadcast_all(lb, ub, rate)
    return dist.TransformedDistribution(
        dist.Uniform((-rate * ub).exp(), (-rate * lb).exp()),
        [dist.transforms.ExpTransform().inv,
         dist.transforms.AffineTransform(loc=0, scale=-1 / rate)])

For example

d = TruncatedExponential(-1, 2, 0.5)
plt.hist(d.sample([100000]).numpy(), density=True, bins=20)



Nice! I had just done it by brute force, without TransformedDistribution and AffineTransform/ExpTransform. Thanks.
EDIT: I can do inference with this, right? I need to call pyro.sample on it and feed the result into a different distribution.

@seanreed1111 yes, you can do inference with the TransformedDistribution wrapper :+1:

awesome, thanks a lot!