Basic conditioning example not working

Hi all,

I seem to be misunderstanding how pyro.condition works. I have a basic example in which one variable x determines another variable y:

import pyro
import pyro.distributions as dist
import torch

def test():
    x = pyro.sample('x', dist.Bernoulli(0.5)).item()
    prob_y = 1 if x == 1 else 0
    y = pyro.sample('y', dist.Bernoulli(prob_y)).item()
    return x, y

pyro.condition(test, data={'y': torch.tensor(1.0)})()

The variables x and y should always have the same value, but after conditioning on y=1, the above frequently returns (0.0, 1.0).

Can anyone point out what I am doing incorrectly?

Thanks!

Here, pyro.condition effectively turns

y = pyro.sample('y', dist.Bernoulli(prob_y)).item()

into

y = pyro.sample('y', dist.Bernoulli(prob_y), obs=torch.tensor(1.0)).item()

That is, it rewrites the given model into a different model in which y is observed. It does not do inference. Finding values of x that are consistent with the observed value of y requires inference, whether variational inference or some other method.
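To illustrate the difference between "conditioning" and "inference", here is a library-free sketch (plain Python, not the Pyro API) of the same two-variable model. Running the conditioned model still draws x from its prior, so x keeps varying even though y is pinned to 1. Weighting each run by the likelihood of the observed y (self-normalized importance sampling with the prior as proposal) is one simple form of inference, and it recovers P(x = 1 | y = 1) = 1. The helper names are hypothetical and chosen for illustration only.

```python
import random

# Hypothetical plain-Python stand-in for the model in the question:
# x ~ Bernoulli(0.5), y ~ Bernoulli(1 if x == 1 else 0), with y observed.

def bernoulli_pmf(value, p):
    """Probability of `value` (0 or 1) under Bernoulli(p)."""
    return p if value == 1 else 1.0 - p

def run_conditioned_model(y_obs, rng):
    """One run of the conditioned model: x is drawn from its prior,
    y is fixed to y_obs, and the likelihood of y_obs is returned as a weight."""
    x = 1 if rng.random() < 0.5 else 0
    prob_y = 1.0 if x == 1 else 0.0
    weight = bernoulli_pmf(y_obs, prob_y)
    return x, y_obs, weight

def posterior_x_is_one(y_obs, num_samples=10_000, seed=0):
    """Self-normalized importance estimate of P(x = 1 | y = y_obs)."""
    rng = random.Random(seed)
    runs = [run_conditioned_model(y_obs, rng) for _ in range(num_samples)]
    total = sum(w for _, _, w in runs)
    return sum(w for x, _, w in runs if x == 1) / total

rng = random.Random(0)
print([run_conditioned_model(1, rng)[:2] for _ in range(10)])  # x still varies
print(posterior_x_is_one(1))  # 1.0: runs with x = 0 get zero weight
```

The unweighted runs behave like the original snippet; only the likelihood weighting removes the (0, 1) cases.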

Thanks @martinjankowiak. I tried to modify the above code to do inference and draw samples from the conditioned model:

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import config_enumerate, infer_discrete

def test():
    x = pyro.sample('x', dist.Bernoulli(0.5))
    prob_y = 1 if x == 1 else 0
    y = pyro.sample('y', dist.Bernoulli(prob_y))
    return x, y

model = config_enumerate(test, 'sequential')
model = pyro.condition(model, data={'y': torch.tensor(1.0)})
model = infer_discrete(model, first_available_dim=-1, temperature=1)
model()

However, I still frequently get (0.0, 1.0). Could you point me in the right direction?

Please start with the tutorials.

Hi @slm,

I think you are seeing (0.0, 1.0) because of your use of config_enumerate; see the "Mechanics of enumeration" section of the Inference with Discrete Latent Variables tutorial.
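As a rough picture of what enumerating a discrete latent computes, here is a library-free sketch (plain Python, not Pyro code): instead of drawing x, it sums the joint probability over both values of x, normalizes to get the exact posterior p(x | y), and then samples from that posterior. With y = 1 observed, the posterior puts all its mass on x = 1, so a correct posterior sampler never returns (0, 1). The function names are hypothetical.

```python
import random

# Hypothetical sketch of exact enumeration for the model in the question.

def bernoulli_pmf(value, p):
    """Probability of `value` (0 or 1) under Bernoulli(p)."""
    return p if value == 1 else 1.0 - p

def joint(x, y):
    """p(x, y) = p(x) * p(y | x) for the model in the question."""
    prob_y = 1.0 if x == 1 else 0.0
    return bernoulli_pmf(x, 0.5) * bernoulli_pmf(y, prob_y)

def posterior_x(y_obs):
    """Exact p(x | y = y_obs), obtained by enumerating x over {0, 1}."""
    unnormalized = {x: joint(x, y_obs) for x in (0, 1)}
    evidence = sum(unnormalized.values())  # p(y = y_obs)
    return {x: w / evidence for x, w in unnormalized.items()}

def sample_x_given_y(y_obs, rng):
    """Draw one x from the enumerated posterior."""
    post = posterior_x(y_obs)
    values = list(post)
    return rng.choices(values, weights=[post[v] for v in values])[0]

print(posterior_x(1))  # {0: 0.0, 1: 1.0}
```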