Truncated LogNormal distribution

Hmm, I believe a Rejector will work in your case, as long as you aren't learning max_x0. It would help to see your error messages, but you should never need to call torch.tensor() on another tensor, as in

log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0)).log()

Wrapping an existing tensor in torch.tensor() copies it and detaches it from the autograd graph (recent PyTorch versions also warn and suggest .clone().detach() instead). You should be able to either .detach() or use torch.no_grad(), as in

with torch.no_grad():
    log_scale = dist.LogNormal(loc, scale_0).cdf(max_x0).log()
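
Putting it together, here is a minimal sketch of a Rejector-based truncated LogNormal; the concrete values of loc, scale_0 and max_x0 below are just placeholders, and max_x0 is assumed to be a fixed tensor rather than a learned parameter:

import torch
import pyro.distributions as dist

loc, scale_0 = torch.tensor(0.0), torch.tensor(1.0)  # placeholder parameters
max_x0 = torch.tensor(2.0)                            # fixed upper truncation point

base = dist.LogNormal(loc, scale_0)
with torch.no_grad():
    # total acceptance probability P(X <= max_x0), detached from the graph
    log_scale = base.cdf(max_x0).log()

truncated = dist.Rejector(
    base,                                       # proposal distribution
    lambda x: (x <= max_x0).type_as(x).log(),   # log acceptance prob: 0 if x <= max_x0, else -inf
    log_scale,
)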

Note that I would generally recommend FoldedDistribution over one-sided truncation, since folding avoids rejection sampling entirely; e.g. you could achieve something like your prior with

prior = dist.TransformedDistribution(
    dist.FoldedDistribution(dist.Normal(max_x0 - loc, scale)),
    dist.transforms.AffineTransform(max_x0, -1.))
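
The idea here is reflection rather than truncation: if X ~ Normal(loc, scale), then max_x0 - |max_x0 - X| folds any mass above max_x0 back below it, so the prior has support (-inf, max_x0] and a log_prob you can differentiate through. A quick sanity check, again with placeholder values for loc, scale and max_x0:

samples = prior.sample((1000,))
assert (samples <= max_x0).all()

Keep in mind the fold piles the reflected mass onto the density just below max_x0 rather than renormalizing it, so this is an approximation to the truncated prior, not an exact replacement.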