Hi,
I’m wondering whether there is a tutorial or code gist that demonstrates how to use Pyro to sample from a defined graphical model.
For example, say you have the following graph, with a distribution associated with each variable and a function assigned to each edge. Can Pyro sample from this graphical model?
C ← A → B → D
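In factorized form, a directed graph like this says the joint distribution is a product of one conditional per node given its parents. Drawing each factor in topological order (ancestral sampling) therefore yields one sample from the joint:

```latex
p(A, B, C, D) = p(A)\,p(B \mid A)\,p(C \mid A)\,p(D \mid B)

a \sim p(A), \quad b \sim p(B \mid A = a), \quad c \sim p(C \mid A = a), \quad d \sim p(D \mid B = b)
```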
If you write your model so that the function returns all the sampled variables at the end, then simply calling the model as a function draws a sample:
```python
def Model():
    ...  # draw each variable with pyro.sample(...), parents before children
    return A, B, C, D
```
This model can still be used with the rest of Pyro (inference algorithms, Predictive, and so on), and calling Model() directly also samples the variables for you.
Example:
```python
import pyro
import pyro.distributions as dist


def model(rain=None, sprinkler=None, grasswet=None):
    # A pyro.sample site is sampled when obs is None and fixed to obs otherwise,
    # so no if/else is needed for the optional observations.
    s_rain = pyro.sample('rain', dist.Bernoulli(0.2), obs=rain)

    # P(sprinkler = 1 | rain): 0.01 if it rains, 0.4 otherwise
    sprinkler_probs = (0.01 * s_rain) + (0.4 * (1 - s_rain))
    s_sprinkler = pyro.sample('sprinkler', dist.Bernoulli(sprinkler_probs), obs=sprinkler)

    # P(grasswet = 1 | sprinkler, rain), one term per row of the probability table
    grasswet_probs = 0. * (1 - s_sprinkler) * (1 - s_rain) + 0.8 * (1 - s_sprinkler) * s_rain \
        + 0.9 * s_sprinkler * (1 - s_rain) + 0.99 * s_sprinkler * s_rain
    s_grasswet = pyro.sample('grasswet', dist.Bernoulli(grasswet_probs), obs=grasswet)

    return s_rain, s_sprinkler, s_grasswet


# calling the model as a function draws one sample of every variable
model()
```
This returns a triple of samples: (rain, sprinkler, grasswet).
Another way to sample is with Predictive. With this approach the model doesn’t even need to return anything, because Predictive collects the values at every sample site.
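```python
from pyro.infer import Predictive

# guide and num_samples come from your own setup; if no guide (and no
# posterior_samples) is given, Predictive samples from the model's prior
predictive = Predictive(model, guide=guide, num_samples=num_samples)
samples = {
    k: v.flatten().detach().numpy()
    for k, v in predictive(...).items()  # pass the model's arguments here
}
```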
The example model above encodes this Bayesian network: rain → sprinkler, rain → grasswet, sprinkler → grasswet.