Hi everyone, I’m very new to probabilistic graphical models and probabilistic machine learning. I’m learning Pyro for my master’s thesis and have the following problem.
I wrote the following model function:
```python
import torch
import pyro
import pyro.distributions as dist


# Method of my model class (self.decoder is a torch.nn.Module)
def model(self, x, next_state):
    pyro.module("decoder", self.decoder)
    batch_size = x.shape[0]
    with pyro.plate("data", batch_size):
        # prior belief over alpha
        alpha_loc = torch.zeros(3)
        alpha_scale = torch.eye(3)
        alpha = pyro.sample("alpha", dist.MultivariateNormal(alpha_loc, alpha_scale))
        probs = pyro.sample("probs", dist.Dirichlet(alpha))
        z = pyro.sample("discrete_latent", dist.Multinomial(probs=probs))
```
and I get this output:
```
Sample Sites:
           data dist     |
               value   5 |
          alpha dist   5 | 3
               value   5 | 3
          probs dist   5 | 3
               value 5 5 | 3
discrete_latent dist   5 | 5 3
               value   5 | 5 3
```
As you can see, the value of `probs` has shape `(5, 5, 3)`. But when I run the same code in a separate Jupyter notebook cell, I get the following output:
```python
import torch
import pyro
import pyro.distributions as dist

with pyro.plate("data", 5):
    alpha_loc = torch.zeros(3)
    alpha_scale = torch.eye(3)
    alpha = pyro.sample("alpha", dist.MultivariateNormal(alpha_loc, alpha_scale))
    print(alpha.shape)
    probs = pyro.sample("probs", dist.Dirichlet(alpha))
    print(probs)
```
```
torch.Size([5, 3])
tensor([[3.0978e-01, 3.3674e-02, 6.5654e-01],
        [1.1755e-38, 9.9533e-01, 4.6714e-03],
        [9.7650e-01, 1.1755e-38, 2.3500e-02],
        [1.1755e-38, 1.0000e+00, 4.7411e-07],
        [3.3333e-01, 3.3333e-01, 3.3333e-01]])
```
The latter output is what I want, but I don’t understand what is wrong with my function. Thank you for your help!