I saw that Pyro is planning to add at least a truncated normal distribution soon. In the meantime, however, I want to implement a truncated normal distribution as a prior for a sample parameter. I came across the Rejector distribution and thought it might provide a solution. I tried the following:
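(The snippet itself was lost from this post; based on the errors discussed in the replies, it was presumably something along these lines, with loc, scale_0, and max_x0 as the parameters:)

import torch
import pyro.distributions as dist

class TruncatedLogNormal(dist.Rejector):
    def __init__(self, loc, scale_0, max_x0):
        propose = dist.LogNormal(loc, scale_0)

        def log_prob_accept(x):
            return (x < max_x0).type_as(x).log()

        # both problems discussed in the replies live on this line:
        # .cdf() fails for a float max_x0 inside LogNormal, and the
        # torch.tensor(...) wrapper around an existing tensor is redundant
        log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0)).log()
        super().__init__(propose, log_prob_accept, log_scale)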
However, this approach doesn't seem to work: it says the module is not callable and also complains about the log in log_scale. Any ideas how to make this work, or alternative solutions?
Hmm, I believe a Rejector will work in your case, as long as you aren't learning max_x0. It would help to see your error messages, but you should never need to call torch.tensor(another tensor) as in
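(presumably referring to the expression quoted later in the thread; the .cdf() call already returns a tensor, so the wrapper and conversion are unnecessary:)

# redundant: .cdf() already returns a tensor when given tensor inputs
log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0)).log()

# simpler, and keeps gradients flowing if loc/scale_0 are learned
# (assuming max_x0 is a tensor; the float case is addressed later in the thread)
log_scale = dist.LogNormal(loc, scale_0).cdf(max_x0).log()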
Thanks for your reply. I am not familiar with the transformed distribution functions, and especially the AffineTransform. Will the way you specify it result in a LogNormal(loc, scale_0) capped at max_x0? Or is it a half-normal distribution which is truncated?
When I do:
log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0)).log()
it results in the error below, even when I switch to torch.log():
File "/opt/project/scripts_methods/FED_models.py", line 101, in __init__
    log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0))
File "/usr/local/lib/python3.6/dist-packages/torch/distributions/transformed_distribution.py", line 138, in cdf
    value = transform.inv(value)
File "/usr/local/lib/python3.6/dist-packages/torch/distributions/transforms.py", line 207, in __call__
    return self._inv._inv_call(x)
File "/usr/local/lib/python3.6/dist-packages/torch/distributions/transforms.py", line 138, in _inv_call
    return self._inverse(y)
File "/usr/local/lib/python3.6/dist-packages/torch/distributions/transforms.py", line 313, in _inverse
    return y.log()
AttributeError: 'float' object has no attribute 'log'
Process finished with exit code 1
I think it has something to do with the cdf function in LogNormal: dist.LogNormal(loc, scale_0).cdf(max_x0) results in an error, but dist.Normal(loc, scale_0).cdf(max_x0) does not for the same input variables.
Oops, my original suggestion would result in an approximately Normal distribution capped at max_x0. For an approximately LogNormal distribution capped at max_x0, you could instead use
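(The snippet from this reply is missing from this extract. One way to realize the folded construction described below, sketched here with Pyro's FoldedDistribution so that log_prob accounts for the reflected tail, might be:)

import math
import pyro.distributions as dist
from pyro.distributions import transforms

c = math.log(max_x0)  # fold point in log space
# |c - X| for X ~ Normal(loc, scale_0), with a log_prob that sums both branches
folded = dist.FoldedDistribution(dist.Normal(c - loc, scale_0))
capped_log_normal = dist.TransformedDistribution(
    folded,
    [transforms.AffineTransform(loc=c, scale=-1.0),  # c - |c - X|: log-space values capped at c
     transforms.ExpTransform()])                     # back from log space: samples lie in (0, max_x0]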
Note that neither of these is exactly a truncated normal or lognormal; instead of truncating, they fold the tail back on itself. This ends up being cheaper and easier to implement, and it has approximately the same shape if you are truncating only a small portion of the distribution. In my own modeling experience, the truncated/folded distributions are usually qualitative representations of domain knowledge, so the exact form doesn't matter. But if you have a true physical motivation for an exactly truncated lognormal, then a Rejector might be more appropriate.
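(For reference, a minimal sketch of the Rejector route for an exactly truncated LogNormal, assuming max_x0 is fixed; it passes a tensor to .cdf(), anticipating the fix below, and it assumes a recent Pyro where Rejector.__init__ accepts batch_shape and event_shape keywords:)

import torch
import pyro.distributions as dist

class TruncatedLogNormal(dist.Rejector):
    def __init__(self, loc, scale_0, max_x0):
        propose = dist.LogNormal(loc, scale_0)

        def log_prob_accept(x):
            # log(1) = 0 where x < max_x0, log(0) = -inf where the proposal is rejected
            return (x < max_x0).type_as(x).log()

        # total log acceptance probability, used to renormalize log_prob;
        # note .cdf() gets a tensor argument (see the float/tensor issue below)
        log_scale = propose.cdf(torch.as_tensor(max_x0)).log()
        super().__init__(propose, log_prob_accept, log_scale,
                         batch_shape=propose.batch_shape,
                         event_shape=propose.event_shape)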
Thanks for posting the error message. It looks like .cdf() doesn't support float arguments. What if you try
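(the suggested call is missing from this extract; presumably the idea was to wrap the float bound in a tensor before calling .cdf(), something like:)

log_scale = dist.LogNormal(loc, scale_0).cdf(torch.tensor(max_x0)).log()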
Well, I implemented your suggested change and it does not complain about the log anymore, which is great, but now it is complaining about: AttributeError: 'TruncatedLogNormal' object has no attribute '_event_shape'
Oh, also in case you want to use HMC or an autoguide, you'll need to set a proper .support attribute. I think you can do this via
from torch.distributions import constraints

class TruncatedLogNormal(dist.Rejector):
    ...

    @constraints.dependent_property
    def support(self):
        return constraints.interval(0, self.max_x0)
That works if all your dims are truncated. If only some of your dims are truncated, I think you could constraints.cat together a constraints.interval with a constraints.greater_than, like
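(The example itself is missing here. A hypothetical sketch, where the first k of d event dimensions are capped at max_x0 and the rest are only constrained to be positive; k and d are illustrative attributes, not names from the original thread:)

@constraints.dependent_property
def support(self):
    # first k dims capped at max_x0, remaining d - k dims merely positive
    return constraints.cat(
        [constraints.interval(0, self.max_x0), constraints.greater_than(0)],
        dim=-1, lengths=[self.k, self.d - self.k])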
Hi @Helena.H, I think this should actually be easier in NumPyro since NumPyro supports TruncatedNormal distributions. I think you can use something like
import jax.numpy as jnp
import numpyro.distributions as dist

dist.TransformedDistribution(
    dist.TruncatedNormal(low=jnp.log(lower_bound), loc=loc, scale=scale),
    dist.transforms.ExpTransform())
If that doesn't compute the correct .support, you could put in a bit more work and create a subclass
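(a hypothetical sketch of such a subclass, pinning the support to the truncation bound; the class name and the dependent_property usage are my guesses, not code from the original thread:)

import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions import constraints, transforms

class LowerTruncatedLogNormal(dist.TransformedDistribution):
    # exp of a Normal truncated below at log(lower_bound)
    def __init__(self, loc, scale, lower_bound, validate_args=None):
        self.lower_bound = lower_bound
        base = dist.TruncatedNormal(low=jnp.log(lower_bound), loc=loc, scale=scale)
        super().__init__(base, transforms.ExpTransform(), validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return constraints.greater_than(self.lower_bound)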