Mixture model

Hi,
I am trying to build a mixture model modified based on the GMM tutorial but I am getting shape error when using plate. The code is as given below.

class BCNN(PyroModule):
    def __init__(self,n_gaussians):
        super().__init__()
        self.cnn1 = PyroModule[nn.Conv2d](in_channels=8, out_channels=16, kernel_size=3,padding=1,stride=1)
        self.cnn1.wieght=PyroSample(Normal(0,1.).expand([16,8]).to_event(2))
        self.cnn1.bias=PyroSample(Normal(0,1.).expand([16]).to_event(1))
        #Maxpool and reduce 
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.act1=nn.ELU()
        
        self.cnn2 = PyroModule[nn.Conv2d](in_channels=16, out_channels=32, kernel_size=3,padding=1,stride=1)
        self.cnn2.wieght=PyroSample(Normal(0,1.).expand([32,16]).to_event(2))
        self.cnn2.bias=PyroSample(Normal(0,1.).expand([32]).to_event(1))
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.act2=nn.ELU()
        
        #The final layer with 
        self.z_mu = PyroModule[nn.Linear](11200,n_gaussians)   #200
        self.z_mu.weight=PyroSample(Normal(0,0.7).expand([n_gaussians,11200]).to_event(2))
        self.z_mu.bias=PyroSample(Normal(0,0.7).expand([n_gaussians]).to_event(1))
    
    def forward(self,x):
        out=self.act1(self.maxpool1(self.cnn1(x)))
        out=self.act2(self.maxpool2(self.cnn2(out)))
        z_h=out.view(out.size(0),-1)
        mu = self.z_mu(z_h) #mean 
        return(mu)

    @config_enumerate
    def model(self,x,y=None):
        mu=self.forward(x)
        sigma=pyro.sample("sigma",LogNormal(0.,n_gaussians))
        pi=pyro.sample("pi",Dirichlet(0.5 * torch.ones(n_gaussians)))
        with pyro.plate("data",x.shape[0]):
            cat = pyro.sample('cat', Categorical(pi))
            obs= pyro.sample('obs', Normal(mu[cat],sigma).to_event(),obs=y)
        return (obs)

network = BCNN(n_gaussians=2)
adam_params = {"lr": 0.000001,"betas": (0.9, 0.999)}
optimizer = Adam(adam_params)
elbo = TraceEnum_ELBO(max_plate_nesting=1)
guide=AutoDiagonalNormal(poutine.block(network.model, hide=['cat']))
svi = SVI(network.model, guide, optimizer, loss=elbo)
losses = []
n_iters = 5000
batch_size=20
train_data = data_utils.TensorDataset(xtrain,ytrain)
train_loader = torch.utils.data.DataLoader(dataset=train_data, 
                                           batch_size=20, 
                                           shuffle=True)
 num_epochs = n_iters / (len(train_data) / batch_size)
 num_epochs = int(num_epochs)
    
 pyro.clear_param_store()
  for epoch in range(num_epochs):
        loss=0.0
        total_epoch_loss_train=0.0
        for i, (x, y) in enumerate(train_loader):
            x = x.requires_grad_()
            loss += svi.step(x.float(), y)
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = loss / normalizer_train  
        if epoch % 100 == 0:
            print(epoch, total_epoch_loss_train)
        losses.append(total_epoch_loss_train)
Error:shape '[1]' is invalid for input of size 0
Trace Shapes:                  
                   Param Sites:                  
         AutoDiagonalNormal.loc 22451            
       AutoDiagonalNormal.scale 22451            
                  Sample Sites:                  
_AutoDiagonalNormal_latent dist     | 22451      
                          value                       | 22451      
                 cnn1.bias dist                  |    16      
                          value                      |    16      
                 cnn2.bias dist                 |    32      
                          value                     |    32      
               z_mu.weight dist              |     2 11200
                          value                     |     2 11200
                 z_mu.bias dist               |     2      
                          value                    |     2      
                     sigma dist                 |            
                          value                    |            
                  Trace Shapes:                  
                   Param Sites:                  
         AutoDiagonalNormal.loc 22451            
       AutoDiagonalNormal.scale 22451            
                  Sample Sites:                  
_AutoDiagonalNormal_latent dist     | 22451      
                          value                       | 22451      
                 cnn1.bias dist                  |    16      
                          value                      |    16      
                 cnn2.bias dist                 |    32      
                          value                     |    32      
               z_mu.weight dist             |     2 11200
                          value                    |     2 11200
                 z_mu.bias dist               |     2      
                          value                    |     2      
                     sigma dist                 |            
                          value                    |

I think you’ll want to remove that final .to_event() in the observe statement

- Normal(mu[cat],sigma).to_event()
+ Normal(mu[cat],sigma)

Thank you Fritzo that helped.