Question about tensor shape

Hi everyone, I’m very new to probabilistic graphical models and probabilistic machine learning. I’m learning Pyro for my master’s thesis and have the following problem.

I wrote a model function as follows:

```
def model(self, x, next_state):
    pyro.module("decoder", self.decoder)
    batch_size = x.shape[0]
    with pyro.plate("data", batch_size):

        # prior belief of alpha distribution
        alpha_loc = torch.zeros(3)
        alpha_scale = torch.eye(3)
        alpha = pyro.sample("alpha", dist.MultivariateNormal(alpha_loc, alpha_scale))
        probs = pyro.sample("probs", dist.Dirichlet(alpha))
        z = pyro.sample("discrete_latent", dist.Multinomial(probs=probs))
```

And I get this output:

```
       Sample Sites:
           data dist         |
               value     5   |
          alpha dist     5   | 3
               value     5   | 3
          probs dist     5   | 3
               value 5   5   | 3
discrete_latent dist     5   | 5 3
               value     5   | 5 3
```

As you can see, the output of `probs` has shape `(5,5,3)`. But when I run the same code in a separate Jupyter notebook cell, I get the following output:

```
with pyro.plate("data", 5):
    alpha_loc = torch.zeros(3)
    alpha_scale = torch.eye(3)
    alpha = pyro.sample("alpha", dist.MultivariateNormal(alpha_loc, alpha_scale))
    print(alpha.shape)
    probs = pyro.sample("probs", dist.Dirichlet(alpha))
    print(probs)
```
```
torch.Size([5, 3])
tensor([[3.0978e-01, 3.3674e-02, 6.5654e-01],
        [1.1755e-38, 9.9533e-01, 4.6714e-03],
        [9.7650e-01, 1.1755e-38, 2.3500e-02],
        [1.1755e-38, 1.0000e+00, 4.7411e-07],
        [3.3333e-01, 3.3333e-01, 3.3333e-01]])
```

The latter output is what I want, but I don’t know what is wrong with my function. Thank you for your help.

the output of `probs` has shape `(5,5,3)`

This looks strange… Could you try:

• `with pyro.plate("data", batch_size, dim=-1)`
• `dist.Multinomial(probs=probs.clone())` to see if there is an in-place operation somewhere…

It would be easier to debug if you isolated the issue by removing unnecessary pieces such as `decoder` and posted a fully reproducible example.