Hi there, I am new to Pyro so I hope this question makes sense. I created the following model:
```python
import pyro
import torch


def model():
    a1 = pyro.sample("a1", pyro.distributions.Normal(0.0, 1.0))
    a2 = pyro.sample("a2", pyro.distributions.Normal(0.0, 1.0))
    # 1.0 if a1 + a2 > 5, otherwise 0.0
    cond_v = torch.where((a1 + a2) > 5, 1., 0.)
    cond = pyro.sample("cond", pyro.distributions.Bernoulli(cond_v))
    return cond
```
For this model, I chose a condition that is unlikely to hold under the prior: since `a1 + a2` is distributed as Normal(0, √2), the probability that `a1 + a2 > 5` is only about 0.0002. When I run inference on this model with `cond` conditioned to 1.0, NUTS seems to “ignore” the conditioning and samples values of `a1` and `a2` for which `cond` would be 0.0. If I change `torch.where((a1 + a2) > 5, 1., 0.)` to `torch.where((a1 + a2) > 4, 1., 0.)`, which makes the condition more likely, it works fine. Maybe I am not understanding NUTS fully, but could someone explain why this happens?