How does pyro.condition work?

I’m working my way through this tutorial: An Introduction to Inference in Pyro

What I don’t understand is the following. In order to get (weight | guess, measurement = 9.5) we can use the pyro.condition function with

    import torch
    import pyro
    import pyro.distributions as dist

    def scale(guess):
        weight = pyro.sample("weight", dist.Normal(guess, 1.0))
        return pyro.sample("measurement", dist.Normal(weight, 0.75))

and

    conditioned_scale = pyro.condition(scale, data={"measurement": torch.tensor(9.5)})

I wrote the following script:

    scale(0.3) # tensor(-1.0905)
    conditioned_scale(0.3) # tensor(-1.0905)

For both calls we get the same sample for the weight. Isn’t the tutorial saying that with conditioned_scale we get a sample of weight from a distribution that is conditioned on measurement = 9.5? If so, shouldn’t the weight samples differ, since in the first call we don’t observe any data but in the second we condition on data?


I believe the second call should return 9.5, as you expected, so there is probably a mismatch somewhere in your code.

Sorry for the misleading example. With -1.0905 I’m referring to the weight drawn by pyro.sample("weight", dist.Normal(guess, 1.0)), not to the return value.

I’m confused that the same weight is chosen in both cases, although the tutorial says that conditioned_scale gives (weight | guess, measurement = 9.5) while the first example should only give (weight | guess), which are two different things, imo.
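These really are two different distributions. Because both the prior and the likelihood are Normal, the posterior has a closed form via the standard conjugate-Normal update (precisions add, and the mean is a precision-weighted average). A small sketch, assuming the prior Normal(guess, 1.0) and likelihood Normal(weight, 0.75) from the model above; the function name normal_posterior is hypothetical:

```python
import math


def normal_posterior(guess, measurement, prior_scale=1.0, obs_scale=0.75):
    """Exact distribution of weight | guess, measurement for the scale model.

    Prior:      weight ~ Normal(guess, prior_scale)
    Likelihood: measurement ~ Normal(weight, obs_scale)
    The posterior is Normal as well; precisions (1 / variance) add up.
    """
    prior_prec = 1.0 / prior_scale ** 2
    obs_prec = 1.0 / obs_scale ** 2
    post_var = 1.0 / (prior_prec + obs_prec)
    # Precision-weighted average of the prior mean and the observation.
    post_mean = post_var * (prior_prec * guess + obs_prec * measurement)
    return post_mean, math.sqrt(post_var)


# weight | guess=0.3 is just the prior Normal(0.3, 1.0), while
# weight | guess=0.3, measurement=9.5 is pulled strongly towards 9.5.
mean, std = normal_posterior(0.3, 9.5)
print(f"posterior mean={mean:.3f}, std={std:.3f}")  # posterior mean=6.188, std=0.600
```

With guess = 8.5 the same formula gives a posterior mean of 9.14 and a standard deviation of 0.6.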

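pyro.condition fixes the value of the named sample statements to the given data: in conditioned_scale, the "measurement" site always returns 9.5 and is marked as observed instead of being sampled. It does not change how the remaining sites are sampled, so calling conditioned_scale(0.3) directly still draws weight from its prior, Normal(guess, 1.0). If you want samples from the conditional distribution, you need to run an inference algorithm such as SVI or MCMC on the conditioned model; that is where the observed value actually gets used. A Pyro model is just a Python function: it takes inputs, executes each line in order, and returns an output. Because you are using the same seed, the first sample statement ("weight") returns the same value in both calls.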