Question about tensor shape

Hi everyone, I’m very new to probabilistic graphical models and probabilistic machine learning. I’m learning Pyro for my master’s thesis and have run into the following problem.

I wrote a model function as follows:

def model(self, x, next_state):
    pyro.module("decoder", self.decoder)
    batch_size = x.shape[0]
    with pyro.plate("data", batch_size):

        # prior belief of alpha distribution
        alpha_loc = torch.zeros(3)
        alpha_scale = torch.eye(3)
        alpha = pyro.sample("alpha", dist.MultivariateNormal(alpha_loc, alpha_scale))
        probs = pyro.sample("probs", dist.Dirichlet(alpha))
        z = pyro.sample("discrete_latent", dist.Multinomial(probs=probs))

When I inspect the shapes, I get this output:

Sample Sites:              
            data dist         |    
                value     5   |    
           alpha dist     5   | 3  
                value     5   | 3  
           probs dist     5   | 3  
                value 5   5   | 3  
 discrete_latent dist     5   | 5 3
                value     5   | 5 3

As you can see, the output of probs has shape (5,5,3). But when I run the same code in another cell of a Jupyter notebook, I get the following output:

with pyro.plate("data", 5):
    alpha_loc = torch.zeros(3)
    alpha_scale = torch.eye(3)
    alpha = pyro.sample("alpha", dist.MultivariateNormal(alpha_loc, alpha_scale))
    print(alpha.shape)
    probs = pyro.sample("probs", dist.Dirichlet(alpha))
    print(probs)
torch.Size([5, 3])
tensor([[3.0978e-01, 3.3674e-02, 6.5654e-01],
        [1.1755e-38, 9.9533e-01, 4.6714e-03],
        [9.7650e-01, 1.1755e-38, 2.3500e-02],
        [1.1755e-38, 1.0000e+00, 4.7411e-07],
        [3.3333e-01, 3.3333e-01, 3.3333e-01]])

The latter output, with shape (5, 3), is what I want, but I don’t know what is wrong with my function. Thank you for your help.

the output of probs has shape (5,5,3)

This looks strange… Could you try:

  • with pyro.plate("data", batch_size, dim=-1)
  • dist.Multinomial(probs=probs.clone()) to see if there is an in-place operation somewhere…

It would be easier to debug if you isolated the issue by removing unnecessary pieces such as the decoder and posted a fully reproducible example.