NUTS ignores conditioning

Hi there, I am new to Pyro so I hope this question makes sense. I created the following model:

    import torch
    import pyro

    def model():
        a1 = pyro.sample("a1", pyro.distributions.Normal(0.0, 1.0))
        a2 = pyro.sample("a2", pyro.distributions.Normal(0.0, 1.0))
        # hard 0/1 indicator of whether a1 + a2 exceeds 5
        cond_v = torch.where((a1 + a2) > 5, 1., 0.)
        cond = pyro.sample("cond", pyro.distributions.Bernoulli(cond_v))
        return cond

For this model, I created a condition which is unlikely to be true (given the samples from the normal distributions). When I run inference on this model and condition on “cond” being 1.0, NUTS simply “ignores” the conditioning and samples values for a1 and a2 such that “cond” ends up being 0.0. If I change torch.where((a1 + a2) > 5, 1., 0.) to torch.where((a1 + a2) > 4, 1., 0.), so that the condition becomes more likely, it works fine. Maybe I am not understanding NUTS entirely, but could someone explain to me why this happens?

Here is how I use NUTS:

from pyro import poutine
from pyro.infer import MCMC, NUTS

def conditioned_model(model, data):
    return poutine.condition(model, data=data)()

kernel = NUTS(conditioned_model, jit_compile=True)
mcmc = MCMC(kernel,
            num_samples=200,
            warmup_steps=50,
            num_chains=1,
            mp_context="spawn")

data_to_condition = {"cond": torch.tensor(1.0)}

mcmc.run(model, data_to_condition)
mcmc.summary(prob=0.9)

The log density of your model is not differentiable w.r.t. a1 and a2. HMC/NUTS requires differentiability for correctness. In some cases it may give reasonable-ish results even when this condition is not satisfied, but bad things can happen when it is not.
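
One possible workaround (just a sketch, assuming a smooth relaxation is acceptable for your use case) is to replace the hard torch.where threshold with a steep sigmoid, so the Bernoulli probability, and therefore the log density, becomes differentiable w.r.t. a1 and a2. The steepness value 10.0 below is an arbitrary choice:

    import torch
    import pyro
    import pyro.distributions as dist

    def smooth_model():
        a1 = pyro.sample("a1", dist.Normal(0.0, 1.0))
        a2 = pyro.sample("a2", dist.Normal(0.0, 1.0))
        # Smooth approximation of the hard (a1 + a2) > 5 indicator;
        # the log density is now differentiable w.r.t. a1 and a2.
        cond_p = torch.sigmoid(10.0 * (a1 + a2 - 5.0))
        cond = pyro.sample("cond", dist.Bernoulli(cond_p))
        return cond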
