Hello there!
I am new to Pyro, and I wish to perform variational inference on a Bayesian network whose nodes are indexed by a time `j` and a station `i`. In the code below, `BN` is the NetworkX graph object containing my network. `BN` also holds the parameters defining the conditional distribution of each node `n = (i, j)` (they are linear Gaussian).
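```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

# Defined elsewhere: BN (a networkx.DiGraph whose nodes are (i, j) tuples),
# `times` (iterable of time indices) and `stations`.

def BN_model():
    values = {}
    # loop over time steps; pyro.markov declares a Markov dependency across steps
    for j in pyro.markov(times):
        # sequential plate: stations at time j are independent given the past
        for i in pyro.plate("time_{}".format(j), len(stations)):
            n = (i, j)
            parents = list(BN.predecessors(n))
            # linear Gaussian CPD: mu plus a weighted sum of the parents' values
            loc = BN.nodes[n]["mu"] + sum(
                BN.edges[p, n]["weight"] * values[str(p)] for p in parents
            )
            scale = BN.nodes[n]["sigma"]
            values[str(n)] = pyro.sample(str(n), dist.Normal(loc=loc, scale=scale))
    return values


def BN_guide():
    nodes = list(BN.nodes())
    for k in pyro.plate("nodes", len(nodes)):
        n = nodes[k]
        # one free Normal per node
        loc = pyro.param("loc_{}".format(n), torch.tensor(0.))
        # positive constraint keeps the learned scale valid during optimization
        scale = pyro.param("scale_{}".format(n), torch.tensor(1.),
                           constraint=constraints.positive)
        pyro.sample(str(n), dist.Normal(loc=loc, scale=scale))
```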
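For reference, here is a dependency-free sketch of the generative process that `BN_model` describes, i.e. ancestral sampling of linear Gaussian nodes one time step after another. It is not Pyro code and only an illustration under stated assumptions: plain dicts (`mu`, `sigma`, and a `weight` dict keyed by `(parent, child)`) stand in for the NetworkX attributes, the helper names `conditional_loc` and `ancestral_sample` are hypothetical, the numbers are made up, and it assumes every parent of `(i, j)` lies at an earlier time.

```python
import random


def conditional_loc(node, values, mu, weight):
    """Mean of a linear Gaussian node: mu[node] + sum_p weight[(p, node)] * values[p]."""
    return mu[node] + sum(
        w * values[p] for (p, child), w in weight.items() if child == node
    )


def ancestral_sample(n_stations, n_times, mu, sigma, weight, rng=None):
    """Draw one joint sample of all nodes (i, j), one time step after another.

    Assumes (hypothetically) that every parent of (i, j) lies at an earlier
    time, so its value is already in `values` when (i, j) is sampled.
    """
    rng = rng if rng is not None else random.Random(0)
    values = {}
    for j in range(n_times):  # outer loop over time, like pyro.markov(times)
        for i in range(n_stations):  # stations at time j only read past values
            node = (i, j)
            loc = conditional_loc(node, values, mu, weight)
            values[node] = rng.gauss(loc, sigma[node])
    return values


# Hypothetical 2-station, 2-time network (all numbers made up for illustration).
mu = {(0, 0): 1.0, (1, 0): -1.0, (0, 1): 0.5, (1, 1): 0.0}
weight = {((0, 0), (0, 1)): 0.8, ((1, 0), (1, 1)): 0.5, ((0, 0), (1, 1)): 0.2}
sigma = {n: 0.0 for n in mu}  # zero noise makes each draw equal its mean

print(ancestral_sample(2, 2, mu, sigma, weight))
```

Setting every `sigma` to zero makes each draw equal its conditional mean, which is a quick way to check the weights: node `(1, 1)` then gets `0.0 + 0.5 * (-1.0) + 0.2 * 1.0`.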
When I run SVI with this model and guide, inference works, but I get a warning I do not understand:

```
UserWarning: Found plate statements in guide but not model: {'nodes'}
```
This warning seems to arise because my model and guide use different independence structures. However, that is precisely what I want:
- In the model, I want all the variables at a given time to be mutually independent conditionally on the past.
- In the guide, I want all the variables in the entire network to be mutually independent, which is why I use a different `plate` loop.
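Stated as factorizations, as a sketch using notation introduced here only for illustration ($x_n$ is the value of node $n = (i, j)$, $\mathrm{pa}(n)$ its parents in `BN`, $\mu_n$, $\sigma_n$, $w_{p,n}$ the `mu`, `sigma` and edge `weight` attributes, and $m_n$, $s_n$ the guide parameters `loc_n`, `scale_n`), the two structures above are:

$$
p(x) = \prod_{j} \prod_{i} \mathcal{N}\Big(x_{(i,j)} \,\Big|\, \mu_{(i,j)} + \sum_{p \in \mathrm{pa}((i,j))} w_{p,(i,j)}\, x_p,\ \sigma_{(i,j)}^2\Big)
$$

$$
q(x) = \prod_{n \in \mathrm{BN}} \mathcal{N}\big(x_n \,\big|\, m_n,\ s_n^2\big)
$$

Assuming parents always belong to earlier times, each model factor couples a node only to the past, whereas each guide factor involves a single node.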
Did I do something wrong here, or am I misunderstanding what `plate` means?
Thanks in advance for any piece of advice.
Giom