Hi everyone, I’m very new to probabilistic graphical models and probabilistic machine learning. I’m learning Pyro for my master’s thesis and have run into the following problem.
I wrote the following model function:
```python
def model(self, x, next_state):
    pyro.module("decoder", self.decoder)
    batch_size = x.shape
    with pyro.plate("data", batch_size):
        # prior belief of alpha distribution
        alpha_loc = torch.zeros(3)
        alpha_scale = torch.eye(3)
        alpha = pyro.sample("alpha", dist.MultivariateNormal(alpha_loc, alpha_scale))
        probs = pyro.sample("probs", dist.Dirichlet(alpha))
        z = pyro.sample("discrete_latent", dist.Multinomial(probs=probs))
```
Running it gives the following output:
```
 Sample Sites:
           data dist       |
                value    5 |
          alpha dist     5 | 3
                value    5 | 3
          probs dist     5 | 3
                value  5 5 | 3
discrete_latent dist     5 | 5 3
                value    5 | 5 3
```
As you can see, the sampled value of `probs` has shape `(5, 5, 3)`. But when I run the same code in a separate Jupyter notebook cell, I get the following output:
```python
with pyro.plate("data", 5):
    alpha_loc = torch.zeros(3)
    alpha_scale = torch.eye(3)
    alpha = pyro.sample("alpha", dist.MultivariateNormal(alpha_loc, alpha_scale))
    print(alpha.shape)
    probs = pyro.sample("probs", dist.Dirichlet(alpha))
    print(probs)
```
```
torch.Size([5, 3])
tensor([[3.0978e-01, 3.3674e-02, 6.5654e-01],
        [1.1755e-38, 9.9533e-01, 4.6714e-03],
        [9.7650e-01, 1.1755e-38, 2.3500e-02],
        [1.1755e-38, 1.0000e+00, 4.7411e-07],
        [3.3333e-01, 3.3333e-01, 3.3333e-01]])
```
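For reference, this is the shape arithmetic I expect, based on my (possibly incomplete) understanding of Pyro's convention: a sample's shape is the distribution's batch shape broadcast against the enclosing plate, followed by its event shape. It is a plain-Python sketch that only models a single plate at `dim=-1`; `broadcast_shapes` and `expected_value_shape` are hypothetical helpers written for illustration, not Pyro functions.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast_shapes(a: Shape, b: Shape) -> Shape:
    """Hypothetical helper (not Pyro): right-aligned broadcasting as in NumPy/PyTorch."""
    n = max(len(a), len(b))
    # Left-pad the shorter shape with size-1 dimensions.
    a = (1,) * (n - len(a)) + a
    b = (1,) * (n - len(b)) + b
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {a} and {b}")
        # A size-1 dimension stretches to match the other one.
        out.append(y if x == 1 else x)
    return tuple(out)


def expected_value_shape(batch_shape: Shape, event_shape: Shape, plate_size: int) -> Shape:
    """Hypothetical helper (not Pyro): value shape inside a single plate at dim=-1."""
    # Assumption: the plate adds one batch dimension at the rightmost batch position.
    batch = broadcast_shapes(batch_shape, (plate_size,))
    return batch + event_shape


PLATE = 5

# MultivariateNormal(zeros(3), eye(3)): batch_shape (), event_shape (3,)
alpha_shape = expected_value_shape((), (3,), PLATE)

# Dirichlet(alpha): last dim of alpha is the event, the leading dims are batch
probs_shape = expected_value_shape(alpha_shape[:-1], alpha_shape[-1:], PLATE)

# Multinomial(probs=probs): same batch/event split as Dirichlet
z_shape = expected_value_shape(probs_shape[:-1], probs_shape[-1:], PLATE)

print(alpha_shape, probs_shape, z_shape)
```

Under these assumptions every site, including `probs`, should come out as `(5, 3)`, which matches the notebook cell but not the traced model.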
The second output (from the notebook cell) is what I want, but I don’t know what is wrong with my function. Thanks for any help!