I saw that Pyro is planning to add at least a truncated normal distribution soon. In the meantime, however, I want to implement a truncated log-normal distribution as the prior for a sampled parameter. I came across the Rejector distribution and thought it might provide a solution. I tried the following:
```python
import torch
import pyro
import pyro.distributions as dist


class TruncatedLogNormal(dist.Rejector):
    def __init__(self, loc, scale_0, max_x0):
        propose = dist.LogNormal(loc, scale_0)

        def log_prob_accept(x):
            # accept only samples below the truncation point max_x0
            return (x < max_x0).type_as(x).log()

        # total acceptance probability P(X < max_x0) under the proposal;
        # cdf needs a tensor argument, a plain float fails inside .log()
        log_scale = dist.LogNormal(loc, scale_0).cdf(torch.as_tensor(max_x0)).log()
        super().__init__(propose, log_prob_accept, log_scale)
```
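For reference, this is the kind of standalone check I run on the rejector (the loc/scale/bound values here are just placeholders):

```python
# Placeholder values; as far as I can tell Rejector.sample only supports an
# empty sample_shape, so I draw one value at a time.
trunc = TruncatedLogNormal(torch.tensor(0.0), torch.tensor(1.0), 3.0)
x = trunc.sample()
print(x.item(), trunc.log_prob(x).item())  # log_prob is finite for x < 3.0
```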
Using this model (and I can assure you that the dimensions of x_m and speed match):
```python
def model_pyro_logit(x_c, x_m, y, pot_loc, pot_scl, spd_loc, spd_scl):
    c = x_c.shape[1]  # number of covariates in x_c
    m = x_m.shape[1]
    n = x_m.shape[0]  # number of data points
    alpha = pyro.sample('alpha', dist.Normal(0.0, 1.0))
    beta_c = pyro.sample('lambda',
                         dist.Normal(torch.zeros(c), torch.ones(c)).to_event(1))
    speed_prior = TruncatedLogNormal(spd_loc, spd_scl, 3.0)
    speed = pyro.param('f_speed_pooled', speed_prior)  # where I try to use the truncated prior
    y_loc = (alpha
             + x_c.matmul(beta_c.unsqueeze(-1)).squeeze(-1)
             + (x_m * speed.unsqueeze(-1)))
    with pyro.plate('data', n):
        pyro.sample('y', dist.Bernoulli(logits=y_loc), obs=y)
```
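For completeness, this is roughly how I run inference (the AutoDiagonalNormal guide and the Adam settings are just what I happen to use, not essential to the problem):

```python
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

# x_c, x_m, y are prepared elsewhere with matching shapes
guide = AutoDiagonalNormal(model_pyro_logit)
svi = SVI(model_pyro_logit, guide, Adam({'lr': 1e-2}), loss=Trace_ELBO())
loss = svi.step(x_c, x_m, y, pot_loc, pot_scl, spd_loc, spd_scl)  # the errors show up here
```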
However, the approach does not seem to be valid: Pyro complains that the module is not callable, and (with max_x0 passed as a plain float) also about the .log() call in log_scale. Any ideas how to make this work, or alternative solutions?
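One thing I am wondering about: since TruncatedLogNormal is a distribution, should the parameter be drawn with pyro.sample instead of pyro.param? As far as I understand, pyro.param expects a tensor (or a callable returning one) as its initial value, not a distribution object, which might explain the "not callable" error. A sketch of what I mean:

```python
# Sketch: use the truncated distribution as the prior of a latent sample site,
# instead of passing it as the initial value of a param.
speed = pyro.sample('f_speed_pooled', TruncatedLogNormal(spd_loc, spd_scl, 3.0))
```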