I seem to be misunderstanding how pyro.condition works. I have a basic example in which one variable x determines another variable y:
import torch
import pyro
import pyro.distributions as dist

def test():
    x = pyro.sample('x', dist.Bernoulli(0.5)).item()
    prob_y = 1 if x == 1 else 0
    y = pyro.sample('y', dist.Bernoulli(prob_y)).item()
    return x, y

pyro.condition(test, data={'y': torch.tensor(1.0)})()
The variables x and y should always have the same value, but after conditioning on y=1, the above frequently returns (0.0, 1.0).
pyro.condition effectively changes
y = pyro.sample('y', dist.Bernoulli(prob_y)).item()
into
y = pyro.sample('y', dist.Bernoulli(prob_y), obs=torch.tensor(1.0)).item()
That is, it rewrites the given model into a different model in which y is observed. It does not do inference. Finding values of x that are consistent with the observed value of y requires inference, whether variational inference or otherwise.
Thanks @martinjankowiak. I tried to modify the above code to do inference and draw samples from the conditioned model:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, infer_discrete

def test():
    x = pyro.sample('x', dist.Bernoulli(0.5))
    prob_y = 1 if x == 1 else 0
    y = pyro.sample('y', dist.Bernoulli(prob_y))
    return x, y

model = config_enumerate(test, 'sequential')
model = pyro.condition(model, data={'y': torch.tensor(1.0)})
model = infer_discrete(model, first_available_dim=-1, temperature=1)
model()
However, I still frequently get (0.0, 1.0). Could you point me in the right direction?