Mixture model

Hi,
I am trying to build a mixture model modified based on the GMM tutorial but I am getting shape error when using plate. The code is as given below.

class BCNN(PyroModule):
def init(self,n_gaussians):
super().init()
self.cnn1 = PyroModule[nn.Conv2d](in_channels=8, out_channels=16, kernel_size=3,padding=1,stride=1)
self.cnn1.wieght=PyroSample(Normal(0,1.).expand([16,8]).to_event(2))
self.cnn1.bias=PyroSample(Normal(0,1.).expand([16]).to_event(1))
#Maxpool and reduce
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.act1=nn.ELU()

    self.cnn2 = PyroModule[nn.Conv2d](in_channels=16, out_channels=32, kernel_size=3,padding=1,stride=1)
    self.cnn2.wieght=PyroSample(Normal(0,1.).expand([32,16]).to_event(2))
    self.cnn2.bias=PyroSample(Normal(0,1.).expand([32]).to_event(1))
    self.maxpool2 = nn.MaxPool2d(kernel_size=2)
    self.act2=nn.ELU()
    
    #The final layer with 
    self.z_mu = PyroModule[nn.Linear](11200,n_gaussians)   #200
    self.z_mu.weight=PyroSample(Normal(0,0.7).expand([n_gaussians,11200]).to_event(2))
    self.z_mu.bias=PyroSample(Normal(0,0.7).expand([n_gaussians]).to_event(1))

def forward(self,x):
    out=self.act1(self.maxpool1(self.cnn1(x)))
    out=self.act2(self.maxpool2(self.cnn2(out)))
    z_h=out.view(out.size(0),-1)
    mu = self.z_mu(z_h) #mean 
    return(mu)

@config_enumerate
def model(self,x,y=None):
    mu=self.forward(x)
    sigma=pyro.sample("sigma",LogNormal(0.,n_gaussians))
    pi=pyro.sample("pi",Dirichlet(0.5 * torch.ones(n_gaussians)))
    with pyro.plate("data",x.shape[0]):
        cat = pyro.sample('cat', Categorical(pi))
        obs= pyro.sample('obs', Normal(mu[cat],sigma).to_event(),obs=y)
    return (obs)

network = BCNN(n_gaussians=2)
adam_params = {“lr”: 0.000001,“betas”: (0.9, 0.999)}
optimizer = Adam(adam_params)
elbo = TraceEnum_ELBO(max_plate_nesting=1)
guide=AutoDiagonalNormal(poutine.block(network.model, hide=[‘cat’]))
svi = SVI(network.model, guide, optimizer, loss=elbo)
losses = []
n_iters = 5000
batch_size=20
train_data = data_utils.TensorDataset(xtrain,ytrain)
train_loader = torch.utils.data.DataLoader(dataset=train_data,
batch_size=20,
shuffle=True)
num_epochs = n_iters / (len(train_data) / batch_size)
num_epochs = int(num_epochs)

pyro.clear_param_store()
for epoch in range(num_epochs):
loss=0.0
total_epoch_loss_train=0.0
for i, (x, y) in enumerate(train_loader):
x = x.requires_grad_()
loss += svi.step(x.float(), y)
normalizer_train = len(train_loader.dataset)
total_epoch_loss_train = loss / normalizer_train
if epoch % 100 == 0:
print(epoch, total_epoch_loss_train)
losses.append(total_epoch_loss_train)

Error:shape ‘[1]’ is invalid for input of size 0
Trace Shapes:
Param Sites:
AutoDiagonalNormal.loc 22451
AutoDiagonalNormal.scale 22451
Sample Sites:
_AutoDiagonalNormal_latent dist | 22451
value | 22451
cnn1.bias dist | 16
value | 16
cnn2.bias dist | 32
value | 32
z_mu.weight dist | 2 11200
value | 2 11200
z_mu.bias dist | 2
value | 2
sigma dist |
value |
Trace Shapes:
Param Sites:
AutoDiagonalNormal.loc 22451
AutoDiagonalNormal.scale 22451
Sample Sites:
_AutoDiagonalNormal_latent dist | 22451
value | 22451
cnn1.bias dist | 16
value | 16
cnn2.bias dist | 32
value | 32
z_mu.weight dist | 2 11200
value | 2 11200
z_mu.bias dist | 2
value | 2
sigma dist |
value |