Using CNN for Inference (SVI)

I am trying to implement a simple CNN for MNIST classification. It doesn’t work properly, If I use nn.Linear instead of nn.Conv2d in the model, it works fine. I don’t know why. Ask for help.

the codes are as following:

import os
import torch
from torch import nn
import torch.nn.functional as F
import pyro
from pyro.nn import PyroModule, PyroSample
from pyro.distributions import Normal, Categorical
from pyro.infer.autoguide.guides import AutoNormal
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, Predictive
from utils import PrefetcherDataLoader
from torchvision import datasets, transforms

class ConvNet(PyroModule):

    def __init__(self, init_loc=torch.tensor(0.), init_scale=torch.tensor(1.), use_cuda=True):
        super(ConvNet, self).__init__()
        self.conv1 = PyroModule[nn.Conv2d](1, 10, 5)
        self.conv2 = PyroModule[nn.Conv2d](10, 20, 5)
        self.fc1 = PyroModule[nn.Linear](320, 50)
        self.fc2 = PyroModule[nn.Linear](50, 10)

        if use_cuda:
            init_loc = init_loc.cuda()
            init_scale = init_scale.cuda()
            self.cuda()
        for m in self.modules():
            for name, value in list(m.named_parameters(recurse=False)):
                setattr(m, name, PyroSample(prior=Normal(init_loc, init_scale).expand(value.shape).to_event(value.dim())))

    def forward(self, x, y=None):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        with pyro.plate("observe_data"):
            obs = pyro.sample('obs', Categorical(logits=x), obs=y)
        return obs

model = ConvNet()
guide = AutoNormal(model)

svi = SVI(model, guide, Adam({'lr': 0.001}), loss=Trace_ELBO())
predictive = Predictive(model, guide=guide, num_samples=20, return_sites=('obs',))

if os.path.exists('BayesianNet.save'):
    pyro.clear_param_store()
    pyro.get_param_store().load('BayesianNet.save')

train_loader = PrefetcherDataLoader(datasets.MNIST('mnist-data/', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.13066045939922333,), std=(0.30810779333114624,))])), batch_size=128, shuffle=True, pin_memory=True)
test_loader = PrefetcherDataLoader(datasets.MNIST('mnist-data/', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.13066045939922333,), std=(0.30810779333114624,))])), batch_size=512, shuffle=False, pin_memory=True)

EPOCHS = 10
LOG_INTERVAL = 15
for epoch in range(1, EPOCHS + 1):
    train_losses = 0
    train_batchs = len(train_loader)
    for batchs_num, (x, y) in enumerate(train_loader):
        loss = svi.step(x, y)
        train_losses += loss
        if (batchs_num + 1) % LOG_INTERVAL == 0 or batchs_num + 1 == train_batchs:
            print(f'\rtrain epoch: [{epoch}/{EPOCHS}], batch: [{batchs_num + 1}/{train_batchs}] loss: {train_losses / (batchs_num + 1): .6f}.', end='')
    print()
    rights, refuses = 0, 0
    test_batchs = len(test_loader)
    for batchs_num, (x, y) in enumerate(test_loader):
        with torch.no_grad():
            dict_samples = predictive(x, None)
        obs = dict_samples['obs'].T
        probs, preds = [], []
        for i in range(obs.shape[0]):
            histogram = obs[i].histc(bins=10, min=0, max=9)
            histogram = histogram / histogram.sum()
            prob, pred = histogram.max(0)
            probs.append(prob)
            preds.append(pred)
        probs, preds = torch.stack(probs), torch.stack(preds)
        refuse_threshold = 0.2  # TODO 这里设定阈值,如果预测值不可信则丢弃
        right_rate = (preds[probs > refuse_threshold] == y[probs > refuse_threshold]).sum() / y.shape[0]
        refuse_rate = (probs <= refuse_threshold).sum() / y.shape[0]
        rights += right_rate
        refuses += refuse_rate
        print(f'\rtest epoch: [{epoch + 1}/{EPOCHS}], batch: [{batchs_num + 1}/{test_batchs}] right: {rights/(batchs_num + 1)*100:.2f}, refuse: {refuses/(batchs_num + 1)*100:.2f}.', end='')
    print()

    pyro.get_param_store().save('BayesianNet.save')

outputs:

train epoch: [1/10], batch: [469/469] loss:  34800.064756.
test epoch: [2/10], batch: [20/20] right: 4.12, refuse: 60.75.
train epoch: [2/10], batch: [469/469] loss:  25019.466096.
test epoch: [3/10], batch: [20/20] right: 4.19, refuse: 60.78.
train epoch: [3/10], batch: [469/469] loss:  16110.607585.
test epoch: [4/10], batch: [20/20] right: 4.00, refuse: 60.77.
train epoch: [4/10], batch: [469/469] loss:  9053.331933..
test epoch: [5/10], batch: [20/20] right: 4.39, refuse: 59.08.
test epoch: [6/10], batch: [20/20] right: 3.82, refuse: 59.95.
train epoch: [6/10], batch: [469/469] loss:  2758.916881.
test epoch: [7/10], batch: [20/20] right: 3.95, refuse: 59.95.
train epoch: [7/10], batch: [469/469] loss:  2096.815076.
test epoch: [8/10], batch: [20/20] right: 3.84, refuse: 60.05.
train epoch: [8/10], batch: [469/469] loss:  1847.571749.
test epoch: [9/10], batch: [20/20] right: 4.03, refuse: 59.99.
train epoch: [9/10], batch: [469/469] loss:  1647.313930.
test epoch: [10/10], batch: [20/20] right: 3.99, refuse: 60.70.
train epoch: [10/10], batch: [270/469] loss:  1550.818323.