I’m trying to get started learning Pyro. The documentation looks very good but it’s also quite confusing near the beginning.
So I’m working with one of the first models introduced, the weather model. Here’s my version of it:
```python
import torch
import pyro

def weather():
    sky = pyro.sample("sky",
        pyro.distributions.Bernoulli(0.3)
    )
    # map the Bernoulli draw (0.0 or 1.0) to a label
    sky = {
        0.0: "cloudy",
        1.0: "sunny"
    }[sky.item()]
    temp_mean = {"cloudy": 55.0, "sunny": 75.0}[sky]
    temp_scale = {"cloudy": 10.0, "sunny": 15.0}[sky]
    temp = pyro.sample("temp",
        pyro.distributions.Normal(temp_mean, temp_scale)
    )
    return sky, temp
```
I can create a conditioned version like this:

```python
conditioned_weather = pyro.condition(
    weather,
    data={"temp": torch.tensor(80)}
)
```
but now I don’t know how to actually evaluate this. Calling `conditioned_weather()` returns either `('sunny', tensor(80))` or `('cloudy', tensor(80))`, but not with the high frequency of sunny outcomes that I would expect.
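For reference, the conditional probability in question can be worked out by hand with Bayes' rule, using the prior $P(\text{sunny}) = 0.3$ and the two Normal likelihoods from the model, where $\mathcal{N}(x;\mu,\sigma)$ is the Normal density with mean $\mu$ and standard deviation $\sigma$:

$$
P(\text{sunny} \mid T = 80)
= \frac{0.3\,\mathcal{N}(80; 75, 15)}{0.3\,\mathcal{N}(80; 75, 15) + 0.7\,\mathcal{N}(80; 55, 10)}
\approx \frac{0.3 \times 0.02516}{0.3 \times 0.02516 + 0.7 \times 0.00175}
\approx 0.86
$$

So a correct inference procedure should report sunny roughly 86% of the time, compared with the 30% seen when the conditioned model is simply run forward.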
I understand that the reason for this is that I need to create a “guide” (variational distribution) and optimise it, but I’m having trouble seeing how to do that because the examples all seem to be doing something more complicated, with multiple observations instead of a single conditioning step. I’m also not sure what the guide should be in this case, because I don’t really want it to be different from the model (or at least I don’t think I do); I really just want to find the conditional distribution.
Would someone be able to work me through the basic steps of how to do this simple task with Pyro?