Oops, my original suggestion would result in an approximately Normal capped at `max_x0`. For an approximately LogNormal capped at `max_x0`, you could instead use
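```python
import torch
import pyro.distributions as dist

max_x0 = torch.as_tensor(max_x0)  # .log() needs a tensor, not a Python float
# log(sample) = log(max_x0) - |z| with z ~ Normal(log(max_x0) - loc, scale): Normal(loc, scale) folded at log(max_x0)
prior = dist.TransformedDistribution(
    dist.FoldedDistribution(dist.Normal(max_x0.log() - loc, scale)),
    [dist.transforms.AffineTransform(max_x0.log(), -1.),
     dist.transforms.ExpTransform()])
```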
Note that neither of these is exactly a truncated Normal or LogNormal; instead of truncating, they fold the tail back on itself. This ends up being cheaper and easier to implement, and it has approximately the same shape if you are truncating only a small portion of the distribution. In my own modeling experience, truncated/folded distributions are usually qualitative representations of domain knowledge, so the exact form doesn’t matter. But if you have a true physical motivation for an exactly truncated LogNormal, then a `Rejector` might be more appropriate.
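One rough way to put a number on "approximately the same shape" is the total variation (TV) distance between the folded and the truncated versions. The sketch below works on log(x), where both are built from Normal(loc, scale); TV is unchanged by the exp map, so the numbers carry over to the LogNormal case. It is a stdlib-only numerical check with hypothetical helper names, not anything from Pyro.

```python
import math

def normal_pdf(y, loc, scale):
    """Density of Normal(loc, scale) at y."""
    return math.exp(-0.5 * ((y - loc) / scale) ** 2) / (scale * math.sqrt(2.0 * math.pi))

def normal_cdf(y, loc, scale):
    """P(Y <= y) for Y ~ Normal(loc, scale)."""
    return 0.5 * (1.0 + math.erf((y - loc) / (scale * math.sqrt(2.0))))

def fold_vs_truncate_tv(loc, scale, cap, n=20_000):
    """TV distance between Normal(loc, scale) folded at `cap` and the same
    Normal truncated to (-inf, cap], computed with the midpoint rule."""
    lo = cap - abs(cap - loc) - 10.0 * scale  # both densities are negligible below this
    h = (cap - lo) / n
    kept_mass = normal_cdf(cap, loc, scale)
    diff = 0.0
    for i in range(n):
        y = lo + (i + 0.5) * h
        # Folding reflects the tail above `cap` onto the region just below it.
        folded = normal_pdf(y, loc, scale) + normal_pdf(2.0 * cap - y, loc, scale)
        # Truncation drops that tail and rescales what is left.
        truncated = normal_pdf(y, loc, scale) / kept_mass
        diff += abs(folded - truncated) * h
    return 0.5 * diff

for cap in (1.0, 2.0, 3.0):
    tail = 1.0 - normal_cdf(cap, 0.0, 1.0)
    tv = fold_vs_truncate_tv(0.0, 1.0, cap)
    print(f"cap={cap}: tail mass {tail:.4f}, TV(fold, truncate) {tv:.4f}")
```

Folding and truncation only disagree about where the tail mass above the cap ends up, so the TV distance can never exceed that tail mass; as the cap moves further into the tail, the two options converge.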
Thanks for posting the error message. It looks like `.cdf()` doesn’t support float arguments. What if you try

```python
log_scale = dist.LogNormal(loc, scale_0).cdf(torch.as_tensor(max_x0)).log()
```

Does that work?
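My reading (an assumption, since the thread does not spell it out) is that `log_scale` is the log normalizer of a LogNormal truncated at `max_x0`, for example the `log_scale` argument of a `Rejector`, which is the total log probability that a proposal is accepted. The stdlib-only sketch below illustrates that idea with hypothetical helpers: propose from LogNormal(loc, scale), keep only draws at or below `max_x0`, and renormalize the density by log P(X <= max_x0). It is not Pyro's `Rejector` implementation.

```python
import math
import random

def lognormal_log_cdf(x, loc, scale):
    """log P(X <= x) for X ~ LogNormal(loc, scale); plays the role of log_scale."""
    return math.log(0.5 * (1.0 + math.erf((math.log(x) - loc) / (scale * math.sqrt(2.0)))))

def sample_truncated_lognormal(loc, scale, max_x0, rng):
    """Rejection sampling: propose LogNormal(loc, scale), accept only x <= max_x0.
    Returns the accepted draw and how many proposals it took."""
    tries = 0
    while True:
        tries += 1
        x = rng.lognormvariate(loc, scale)
        if x <= max_x0:
            return x, tries

def truncated_lognormal_log_prob(x, loc, scale, max_x0):
    """LogNormal log-density on (0, max_x0], renormalized by log_scale."""
    if x > max_x0:
        return -math.inf
    z = (math.log(x) - loc) / scale
    log_p = -0.5 * z * z - math.log(x * scale * math.sqrt(2.0 * math.pi))
    return log_p - lognormal_log_cdf(max_x0, loc, scale)

loc, scale, max_x0 = 0.0, 1.0, 2.0
rng = random.Random(0)
draws = [sample_truncated_lognormal(loc, scale, max_x0, rng) for _ in range(20_000)]
acceptance = len(draws) / sum(tries for _, tries in draws)
print(acceptance, math.exp(lognormal_log_cdf(max_x0, loc, scale)))
```

The empirical acceptance rate should be close to exp(log_scale). It also shows the cost of exact truncation: each accepted draw takes 1/exp(log_scale) proposals on average, which gets expensive when the cap cuts off most of the mass.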