I am new to Pyro. I would like to know whether Pyro can be used for inference in graphical models, and I want to start with a simple case.
The discrete Bayesian network has three nodes, A, B and C, all of which are binary variables. The graph structure is $B \leftarrow A \rightarrow C$, so the joint probability factorizes as

$$P(A, B, C) = P(A)\,P(B \mid A)\,P(C \mid A).$$

Now we have some observations of (A, B, C), e.g. [0,1,0], [1,1,1], [0,0,1] … The goal is to compute $A^* = \arg\max_{A} P(A \mid B=0)$.
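As a quick check of what the query actually needs (my own derivation, assuming the factorization above), C is not part of the query and sums out because $\sum_{C} P(C \mid A) = 1$:

$$
P(A \mid B=0) = \frac{\sum_{C} P(A)\,P(B=0 \mid A)\,P(C \mid A)}{\sum_{A'}\sum_{C} P(A')\,P(B=0 \mid A')\,P(C \mid A')}
= \frac{P(A)\,P(B=0 \mid A)}{\sum_{A'} P(A')\,P(B=0 \mid A')},
\qquad
A^* = \arg\max_{A} P(A)\,P(B=0 \mid A).
$$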
Since A, B and C are all observed and there are dependencies among them, the traditional approach is to first learn the parameters of the conditional probabilities $P(A)$, $P(B \mid A)$ and $P(C \mid A)$ and then perform inference. I am very interested in what Pyro can do here!
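For reference, the traditional two-step approach (learn the tables by counting, then query) can be sketched in plain Python on the three example rows. The add-one (Laplace) smoothing is my own choice, not something required by the problem: it avoids a zero estimate for P(B=0 | A=1), which only has one row, and it equals the posterior mean under the Beta(1, 1) priors used in the Pyro code below. The helper name `estimate` is hypothetical.

```python
# The three example rows from the question; columns are (A, B, C).
data = [(0, 1, 0), (1, 1, 1), (0, 0, 1)]

def estimate(successes, trials, alpha=1.0):
    # Add-alpha (Laplace) smoothed estimate of P(X=1); alpha=0 gives plain counting (MLE).
    return (successes + alpha) / (trials + 2 * alpha)

# Step 1: parameter learning by counting.
p_a1 = estimate(sum(r[0] for r in data), len(data))  # P(A=1)
p_b1 = {}  # P(B=1 | A=a)
p_c1 = {}  # P(C=1 | A=a), learned for completeness; the query below does not use it
for a_val in (0, 1):
    rows = [r for r in data if r[0] == a_val]
    p_b1[a_val] = estimate(sum(r[1] for r in rows), len(rows))
    p_c1[a_val] = estimate(sum(r[2] for r in rows), len(rows))

# Step 2: inference. C is unobserved and sums out, so P(A | B=0) is proportional to P(A) P(B=0 | A).
prior = {0: 1 - p_a1, 1: p_a1}
unnorm = {a: prior[a] * (1 - p_b1[a]) for a in (0, 1)}
z = sum(unnorm.values())
posterior = {a: unnorm[a] / z for a in (0, 1)}
a_star = max(posterior, key=posterior.get)
print(posterior, a_star)  # A* = 0, with P(A=0 | B=0) = 9/13, about 0.692
```

With only three rows the answer is sensitive to the smoothing: with `alpha=0` the estimate of P(B=0 | A=1) is exactly 0, and the posterior collapses to P(A=0 | B=0) = 1.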
Thank you so much for any help and attention.
I also wrote some code, but it does not work: it fails with `RuntimeError: result type Float can't be cast to the desired output type Int`. I am not sure whether I am on the right track.
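```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def bn(data):
    # Uniform Beta(1, 1) priors on the five conditional probabilities
    pa = pyro.sample('pa', dist.Beta(1., 1.))    # P(A=1)
    pb0 = pyro.sample('pb0', dist.Beta(1., 1.))  # P(B=1 | A=0)
    pb1 = pyro.sample('pb1', dist.Beta(1., 1.))  # P(B=1 | A=1)
    pc0 = pyro.sample('pc0', dist.Beta(1., 1.))  # P(C=1 | A=0)
    pc1 = pyro.sample('pc1', dist.Beta(1., 1.))  # P(C=1 | A=1)
    # Bernoulli.log_prob expects float observations; integer tensors
    # (e.g. dtype=torch.int) are the likely cause of the RuntimeError above
    data = torch.as_tensor(data, dtype=torch.float)
    for i in pyro.plate('data_loop', len(data)):
        A = pyro.sample('A_{}'.format(i), dist.Bernoulli(pa), obs=data[i, 0])
        # use the conditional probabilities that match the observed value of A
        if A == 0:
            B = pyro.sample('B_{}'.format(i), dist.Bernoulli(pb0), obs=data[i, 1])
            C = pyro.sample('C_{}'.format(i), dist.Bernoulli(pc0), obs=data[i, 2])
        else:
            B = pyro.sample('B_{}'.format(i), dist.Bernoulli(pb1), obs=data[i, 1])
            C = pyro.sample('C_{}'.format(i), dist.Bernoulli(pc1), obs=data[i, 2])
    return B, C

def guide(data):
    # Row k holds the (alpha, beta) concentrations of the Beta guide for one parameter.
    # Concentrations only need to be positive (a simplex constraint would force alpha + beta = 1).
    p = pyro.param('p', torch.randn(5, 2).exp(), constraint=constraints.positive)
    pa = pyro.sample('pa', dist.Beta(p[0, 0], p[0, 1]))
    pb0 = pyro.sample('pb0', dist.Beta(p[1, 0], p[1, 1]))
    pb1 = pyro.sample('pb1', dist.Beta(p[2, 0], p[2, 1]))
    pc0 = pyro.sample('pc0', dist.Beta(p[3, 0], p[3, 1]))
    pc1 = pyro.sample('pc1', dist.Beta(p[4, 0], p[4, 1]))
    return pa, pb0, pb1, pc0, pc1
```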