How does conditional inference work?

Hi!

I'm currently studying the Introduction to Inference tutorial ((DEPRECATED) An Introduction to Inference in Pyro — Pyro Tutorials 1.8.6 documentation).
I wrote a simple example with normal distributions:

```python
import numpy as np
import pyro
import pyro.distributions as dist

def model():
    X = pyro.sample("X", dist.Normal(3, 1))
    Y = pyro.sample("Y", dist.Normal(5, 1))
    S = pyro.deterministic("S", X + Y)
    return X, Y, S

conditioned_model = pyro.condition(model, data={"S": 15})
conditioned_model()
# Out: (tensor(4.1030), tensor(5.9444), 15)
```

```python
N = 1000
X, Y = [], []
for _ in range(N):
    c = conditioned_model()
    X.append(float(c[0]))
    Y.append(float(c[1]))

np.mean(X), np.mean(Y)
# Out: (2.9528431569337843, 4.997446128606796)
```

```python
do_model = pyro.do(model, data={"S": 15})
do_model()
# Out: (tensor(2.2040), tensor(4.0202), 15)
```

```python
N = 1000
X, Y = [], []
for _ in range(N):
    c = do_model()  # sample from the intervened model
    X.append(float(c[0]))
    Y.append(float(c[1]))

np.mean(X), np.mean(Y)
# Out: (2.9624990526437758, 5.0139990842342375)
```

In the first case, I want to sample from the conditional distribution (X, Y) | X + Y = 15. But 4.1030 + 5.9444 is not equal to 15, and the sample means stay close to the prior means of 3 and 5.
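For reference, this Gaussian case has a closed form (a standard change-of-variables result added for checking, not part of the original post). Because the map (X, Y) → (X, X + Y) has unit Jacobian, the conditional density of X is proportional to the product of the two prior densities evaluated on the line x + y = 15:

```latex
p(x \mid X + Y = 15) \propto p_X(x)\, p_Y(15 - x)
  \propto \exp\!\Big(-\tfrac{(x-3)^2}{2} - \tfrac{(x-10)^2}{2}\Big)
  \propto \exp\!\big(-(x - 6.5)^2\big)

\Rightarrow\quad X \mid X+Y=15 \sim \mathcal{N}\big(6.5, \tfrac{1}{2}\big),
\qquad Y = 15 - X \sim \mathcal{N}\big(8.5, \tfrac{1}{2}\big)
```

So correct conditional samples should have means near 6.5 and 8.5, not near the prior means 3 and 5.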

Why does it work like this? Or what am I doing wrong?

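In general, running a model wrapped in pyro.condition only fixes the values of the named sites; the remaining sample statements are still drawn from their priors. To sample from the posterior distribution of a model's sample statements given data from pyro.condition, you need to wrap your conditioned model in an inference algorithm such as HMC or SVI.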

In your particular case, note that Pyro doesn't support conditioning on the value of a pyro.deterministic statement, so the queries in your code will not produce meaningful results regardless of the inference algorithm. See this old forum topic for previous discussion of conditioning on deterministic computations.
