Hmm, I believe a rejector will work in your case, as long as you aren’t learning max_x0. It would help to see your error messages, but you should never need to call torch.tensor(another tensor)
as in
log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0)).log()
You should be able to either detach or use torch.no_grad()
as in
with torch.no_grad():
    log_scale = dist.LogNormal(loc, scale_0).cdf(max_x0).log()
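or, equivalently, detach so that no gradient flows through that value:

log_scale = dist.LogNormal(loc, scale_0).cdf(max_x0).detach().log()

In case it helps, here is a rough sketch of how that log_scale could feed into a Rejector for an upper truncation. I'm assuming the (propose, log_prob_accept, log_scale) constructor of pyro.distributions.Rejector, a fixed max_x0, and placeholder parameter values, so treat it as a starting point rather than a drop-in:

import torch
import pyro.distributions as dist

loc, scale_0, max_x0 = torch.tensor(0.), torch.tensor(1.), torch.tensor(3.)  # placeholders

propose = dist.LogNormal(loc, scale_0)
with torch.no_grad():
    log_scale = propose.cdf(max_x0).log()  # log of the total acceptance probability

def log_prob_accept(x):
    # log acceptance probability: 0 at or below max_x0, -inf above it
    return torch.where(x <= max_x0, torch.zeros_like(x), torch.full_like(x, -float("inf")))

truncated = dist.Rejector(propose, log_prob_accept, log_scale)
x = truncated.sample()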
Note that I would generally recommend FoldedDistribution over one-sided truncation; e.g. you could achieve something like your prior with
prior = dist.TransformedDistribution(
    dist.FoldedDistribution(dist.Normal(max_x0 - loc, scale)),
    dist.transforms.AffineTransform(-max_x0, -1.))
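As a quick sanity check that sampling and scoring go through (placeholder values here; substitute your own loc, scale, max_x0):

import torch
import pyro.distributions as dist

loc, scale, max_x0 = torch.tensor(0.), torch.tensor(1.), torch.tensor(3.)
prior = dist.TransformedDistribution(
    dist.FoldedDistribution(dist.Normal(max_x0 - loc, scale)),
    dist.transforms.AffineTransform(-max_x0, -1.))
x = prior.sample(torch.Size([1000]))   # draw a batch of samples
print(prior.log_prob(x).shape)         # scores them: torch.Size([1000])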