Truncated LogNormal distribution

I saw that Pyro is planning to add at least a truncated normal distribution soon. However, I want to implement a truncated lognormal distribution now, as the prior for a sampled parameter. I came across the Rejector distribution and thought it might provide a solution. I tried the following:

class TruncatedLogNormal(dist.Rejector):
    def __init__(self, loc, scale_0, max_x0):
        propose = dist.LogNormal(loc, scale_0)

        def log_prob_accept(x):
            return (x[0] < max_x0).type_as(x).log()

        log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0)).log()
        super(TruncatedLogNormal, self).__init__(propose, log_prob_accept, log_scale)

Using this model (I can confirm that the dimensions of x_m and speed match):

def model_pyro_logit(x_c, x_m, y, pot_loc, pot_scl, spd_loc, spd_scl):
    c = x_c.shape[1]
    m = x_m.shape[1]
    n = x_m.shape[0]
    alpha = pyro.sample('alpha', dist.Normal(0.0, 1.0))
    beta_c = pyro.sample('lambda', dist.Normal(torch.zeros(c), torch.ones(c)).to_event(1))

    speed_prior = TruncatedLogNormal(spd_loc, spd_scl, 3.0)
    # a prior is drawn with pyro.sample (pyro.param expects an initial tensor, not a distribution)
    speed = pyro.sample('f_speed_pooled', speed_prior)
    y_loc = alpha + x_c.matmul(beta_c.unsqueeze(-1)).squeeze(-1) + (x_m * speed.unsqueeze(-1))

    with pyro.plate('data', n):
        pyro.sample('y', dist.Bernoulli(logits=y_loc), obs=y)

However, this approach does not seem to work: I get an error saying that the module is not callable, and it also complains about the .log() in log_scale. Any ideas how to make this work, or alternative solutions?

Hmm, I believe a Rejector will work in your case, as long as you aren't learning max_x0. It would help to see your error messages, but note that you should never need to call torch.tensor() on another tensor, as in

log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0)).log()

Instead, you can either .detach() the result or compute it inside a torch.no_grad() block, as in

with torch.no_grad():
    log_scale = dist.LogNormal(loc, scale_0).cdf(max_x0).log()
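As a quick check that the two options are equivalent, here is a minimal sketch using plain torch.distributions (the values of loc, scale_0 and max_x0 are made up for illustration). Both versions give the same number, and neither result is attached to the autograd graph, so the Rejector's normalizer is treated as a constant:

```python
import torch
import torch.distributions as td

# Illustrative values (assumptions, not from the original model).
loc = torch.tensor(0.5, requires_grad=True)
scale_0 = torch.tensor(1.0)
max_x0 = torch.tensor(3.0)

# Option 1: compute the normalizer without recording gradients.
with torch.no_grad():
    log_scale_nograd = td.LogNormal(loc, scale_0).cdf(max_x0).log()

# Option 2: compute it normally, then detach it from the graph.
log_scale_detached = td.LogNormal(loc, scale_0).cdf(max_x0).log().detach()

print(log_scale_nograd, log_scale_detached)
```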

Note that I would generally recommend FoldedDistribution over one-sided truncation. For example, you could achieve something like your prior with

prior = dist.TransformedDistribution(
    dist.FoldedDistribution(dist.Normal(max_x0 - loc, scale)),
    dist.transforms.AffineTransform(max_x0, -1.))

Thanks for your reply. I am not familiar with TransformedDistribution, and especially not with AffineTransform. Will the way you specified it result in a LogNormal(loc, scale_0) capped at max_x0, or is it a truncated half-normal distribution?

When I do:
log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0)).log()

it results in the error below, even when I switch to torch.log():

  File "/opt/project/scripts_methods/FED_models.py", line 101, in __init__
    log_scale = torch.tensor(dist.LogNormal(loc, scale_0).cdf(max_x0))
  File "/usr/local/lib/python3.6/dist-packages/torch/distributions/transformed_distribution.py", line 138, in cdf
    value = transform.inv(value)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributions/transforms.py", line 207, in __call__
    return self._inv._inv_call(x)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributions/transforms.py", line 138, in _inv_call
    return self._inverse(y)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributions/transforms.py", line 313, in _inverse
    return y.log()
AttributeError: 'float' object has no attribute 'log'

Process finished with exit code 1

I think it has something to do with the cdf method of LogNormal: dist.LogNormal(loc, scale_0).cdf(max_x0) raises this error, but dist.Normal(loc, scale_0).cdf(max_x0) does not, for the same input variables.

Oops, my original suggestion would result in an approximately Normal distribution capped at max_x0. For an approximately LogNormal distribution capped at max_x0, you could instead use

prior = dist.TransformedDistribution(
    dist.FoldedDistribution(dist.Normal(max_x0.log() - loc, scale)),
    [dist.transforms.AffineTransform(max_x0.log(), -1.),
     dist.transforms.ExpTransform()])

Note that neither of these is exactly a truncated Normal or LogNormal: instead of truncating, they fold the tail back on itself. This is cheaper and easier to implement, and it has approximately the same shape if you are truncating only a small portion of the distribution. In my own modeling experience, truncated/folded distributions are usually qualitative representations of domain knowledge, so the exact form doesn't matter. But if you have a true physical motivation for an exactly truncated LogNormal, then a Rejector might be more appropriate.
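To make the difference between folding and truncation concrete, here is a small stdlib-only sketch for the plain Normal case (the helper names and the numbers mu = 1, sigma = 1, cap = 3 are illustrative). For y = cap - |w| with w ~ Normal(cap - mu, sigma), the density on y <= cap is phi(y; mu, sigma) + phi(2*cap - y; mu, sigma), i.e. the mass above the cap is reflected back below it. The truncated density is instead phi(y; mu, sigma) / Phi((cap - mu)/sigma), which rescales everything. Both integrate to 1; the folded one leaves the bulk almost unchanged and piles the reflected mass up near the cap (at the cap itself its density doubles):

```python
import math

def normal_pdf(x, mu, sigma):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def normal_cdf(x, mu, sigma):
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

def folded_pdf(y, mu, sigma, cap):
    """Density of y = cap - |w| with w ~ Normal(cap - mu, sigma)."""
    if y > cap:
        return 0.0
    return normal_pdf(y, mu, sigma) + normal_pdf(2 * cap - y, mu, sigma)

def truncated_pdf(y, mu, sigma, cap):
    """Density of Normal(mu, sigma) conditioned on y <= cap."""
    if y > cap:
        return 0.0
    return normal_pdf(y, mu, sigma) / normal_cdf(cap, mu, sigma)

def integrate(f, lo, hi, n=20000):
    """Midpoint-rule integral of f over [lo, hi]."""
    h = (hi - lo) / n
    return sum(f(lo + (i + 0.5) * h) for i in range(n)) * h

mu, sigma, cap = 1.0, 1.0, 3.0
lo = mu - 12 * sigma  # far enough into the lower tail to be negligible
print(integrate(lambda y: folded_pdf(y, mu, sigma, cap), lo, cap))
print(integrate(lambda y: truncated_pdf(y, mu, sigma, cap), lo, cap))
print(folded_pdf(mu, mu, sigma, cap), truncated_pdf(mu, mu, sigma, cap))
print(folded_pdf(cap, mu, sigma, cap), truncated_pdf(cap, mu, sigma, cap))
```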

Thanks for posting the error message. It looks like LogNormal.cdf() doesn't support float arguments, since it calls .log() on the value. What if you try

log_scale = dist.LogNormal(loc, scale_0).cdf(torch.as_tensor(max_x0)).log()

Does that work?
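For reference, LogNormal in torch (which Pyro's dist.LogNormal wraps) is a TransformedDistribution of a Normal with an ExpTransform, so .cdf() first applies the inverse transform, value.log(), and that is the call that fails on a Python float. Once the bound is a tensor, the result equals the Normal CDF at log(max_x0), i.e. log_scale = log Phi((log max_x0 - loc) / scale_0), the log probability that a LogNormal proposal is accepted. A minimal check with made-up parameter values:

```python
import math
import torch
import torch.distributions as td

# Illustrative values (assumptions).
loc, scale_0, max_x0 = torch.tensor(0.0), torch.tensor(1.0), 3.0

with torch.no_grad():
    # Wrapping max_x0 in a tensor lets ExpTransform's inverse call .log() on it.
    log_scale = td.LogNormal(loc, scale_0).cdf(torch.as_tensor(max_x0)).log()

# Same quantity via the standard Normal CDF of the standardized log bound (stdlib only).
z = (math.log(max_x0) - loc.item()) / scale_0.item()
expected = math.log(0.5 * (1.0 + math.erf(z / math.sqrt(2.0))))
print(log_scale.item(), expected)
```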

Well, I implemented your suggested change and it no longer complains about the log, which is great, but now it complains about: AttributeError: 'TruncatedLogNormal' object has no attribute '_event_shape'

which I see you opened a pull request for (Add super().__init__() call to Rejector by fritzo · Pull Request #2389 · pyro-ppl/pyro · GitHub). I am currently using Pyro version 1.3.0, so I guess updating to 1.4.0 will fix the problem 🙂

Oh, also: in case you want to use HMC or an autoguide, you'll need to set a proper .support attribute. I think you can do this via

from torch.distributions import constraints

class TruncatedLogNormal(dist.Rejector):
    ...  # __init__ must also store max_x0 as self.max_x0 for the support below
    @constraints.dependent_property
    def support(self):
        return constraints.interval(0, self.max_x0)

That works if all your dims are truncated. If only some of your dims are truncated, I think you could use constraints.cat to join a constraints.interval with a constraints.greater_than, like

@constraints.dependent_property
def support(self):
    # Truncate only the first coordinate of the event; the rest are merely positive.
    return constraints.cat([
        constraints.interval(0., self.max_x0),
        constraints.greater_than(0.)],
        dim=-1, lengths=[1, self.event_shape[-1] - 1])
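The cat-based support can be sanity-checked with plain torch. In this sketch (event size 3 and max_x0 = 3.0 are illustrative assumptions), the first coordinate is checked against interval(0, max_x0) and the remaining two against greater_than(0). Note that constraints.cat needs dim=-1 and explicit lengths here; without them it defaults to dim 0 and assumes each constraint covers a single element:

```python
import torch
from torch.distributions import constraints

max_x0 = 3.0
event_size = 3  # illustrative: only the first coordinate is truncated

support = constraints.cat(
    [constraints.interval(0., max_x0),   # truncated coordinate
     constraints.greater_than(0.)],      # remaining positive coordinates
    dim=-1, lengths=[1, event_size - 1])

print(support.check(torch.tensor([2.0, 10.0, 5.0])))  # tensor([True, True, True])
print(support.check(torch.tensor([4.0, 1.0, 1.0])))   # tensor([False, True, True])
```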

Is this also the case if I use your suggestion (TransformedDistribution with AffineTransform)?

Yes, you’ll also need to add a .support property when using a TransformedDistribution, if you want to use an autoguide or MCMC.

Worked perfectly!

@fritzo Is there any way to get this to work in NumPyro as well? I can't find a Rejector implementation in NumPyro.

Hi @Helena.H, I think this should actually be easier in NumPyro since NumPyro supports TruncatedNormal distributions. I think you can use something like

import jax.numpy as jnp
import numpyro.distributions as dist

dist.TransformedDistribution(
    dist.TruncatedNormal(low=jnp.log(lower_bound), loc=loc, scale=scale),
    dist.transforms.ExpTransform())

If that doesn’t compute the correct .support, you could put in a bit more work and create a subclass

class TruncatedLogNormal(dist.TransformedDistribution):
    def __init__(self, low, loc, scale, validate_args=None):
        self.low = low
        base_dist = dist.TruncatedNormal(low=jnp.log(low), loc=loc, scale=scale)
        transform = dist.transforms.ExpTransform()
        super().__init__(base_dist, transform, validate_args=validate_args)

    @property
    def support(self):
        # Override the support property instead of assigning to self.support.
        return dist.constraints.greater_than(self.low)
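Unlike the folded construction, exponentiating a truncated Normal gives an exactly truncated LogNormal: by the change of variables x = exp(z), the density on x > low is the LogNormal density divided by P(X > low) = 1 - Phi((log(low) - loc) / scale). The stdlib-only sketch below (function names and the values low = 0.8, loc = 0, scale = 0.5 are illustrative, not NumPyro API) checks numerically that this density integrates to 1:

```python
import math

def normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def lognormal_pdf(x, loc, scale):
    return math.exp(-0.5 * ((math.log(x) - loc) / scale) ** 2) / (x * scale * math.sqrt(2 * math.pi))

def truncated_lognormal_pdf(x, low, loc, scale):
    """Density of exp(z) with z ~ Normal(loc, scale) truncated to z > log(low)."""
    if x <= low:
        return 0.0
    mass_above_low = 1.0 - normal_cdf((math.log(low) - loc) / scale)
    return lognormal_pdf(x, loc, scale) / mass_above_low

def integrate(f, lo, hi, n=100_000):
    """Midpoint-rule integral of f over [lo, hi]."""
    h = (hi - lo) / n
    return sum(f(lo + (i + 0.5) * h) for i in range(n)) * h

low, loc, scale = 0.8, 0.0, 0.5
# The upper limit 60 is far enough into the tail (about 8 scales in log space).
total = integrate(lambda x: truncated_lognormal_pdf(x, low, loc, scale), low, 60.0)
print(total)
```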