I am trying to implement a simple CNN for MNIST classification. It doesn’t work properly, If I use nn.Linear instead of nn.Conv2d in the model, it works fine. I don’t know why. Ask for help.
the codes are as following:
import os
import torch
from torch import nn
import torch.nn.functional as F
import pyro
from pyro.nn import PyroModule, PyroSample
from pyro.distributions import Normal, Categorical
from pyro.infer.autoguide.guides import AutoNormal
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, Predictive
from utils import PrefetcherDataLoader
from torchvision import datasets, transforms
class ConvNet(PyroModule):
def __init__(self, init_loc=torch.tensor(0.), init_scale=torch.tensor(1.), use_cuda=True):
super(ConvNet, self).__init__()
self.conv1 = PyroModule[nn.Conv2d](1, 10, 5)
self.conv2 = PyroModule[nn.Conv2d](10, 20, 5)
self.fc1 = PyroModule[nn.Linear](320, 50)
self.fc2 = PyroModule[nn.Linear](50, 10)
if use_cuda:
init_loc = init_loc.cuda()
init_scale = init_scale.cuda()
self.cuda()
for m in self.modules():
for name, value in list(m.named_parameters(recurse=False)):
setattr(m, name, PyroSample(prior=Normal(init_loc, init_scale).expand(value.shape).to_event(value.dim())))
def forward(self, x, y=None):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = x.flatten(1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
with pyro.plate("observe_data"):
obs = pyro.sample('obs', Categorical(logits=x), obs=y)
return obs
model = ConvNet()
guide = AutoNormal(model)
svi = SVI(model, guide, Adam({'lr': 0.001}), loss=Trace_ELBO())
predictive = Predictive(model, guide=guide, num_samples=20, return_sites=('obs',))
if os.path.exists('BayesianNet.save'):
pyro.clear_param_store()
pyro.get_param_store().load('BayesianNet.save')
train_loader = PrefetcherDataLoader(datasets.MNIST('mnist-data/', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.13066045939922333,), std=(0.30810779333114624,))])), batch_size=128, shuffle=True, pin_memory=True)
test_loader = PrefetcherDataLoader(datasets.MNIST('mnist-data/', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.13066045939922333,), std=(0.30810779333114624,))])), batch_size=512, shuffle=False, pin_memory=True)
EPOCHS = 10
LOG_INTERVAL = 15
for epoch in range(1, EPOCHS + 1):
train_losses = 0
train_batchs = len(train_loader)
for batchs_num, (x, y) in enumerate(train_loader):
loss = svi.step(x, y)
train_losses += loss
if (batchs_num + 1) % LOG_INTERVAL == 0 or batchs_num + 1 == train_batchs:
print(f'\rtrain epoch: [{epoch}/{EPOCHS}], batch: [{batchs_num + 1}/{train_batchs}] loss: {train_losses / (batchs_num + 1): .6f}.', end='')
print()
rights, refuses = 0, 0
test_batchs = len(test_loader)
for batchs_num, (x, y) in enumerate(test_loader):
with torch.no_grad():
dict_samples = predictive(x, None)
obs = dict_samples['obs'].T
probs, preds = [], []
for i in range(obs.shape[0]):
histogram = obs[i].histc(bins=10, min=0, max=9)
histogram = histogram / histogram.sum()
prob, pred = histogram.max(0)
probs.append(prob)
preds.append(pred)
probs, preds = torch.stack(probs), torch.stack(preds)
refuse_threshold = 0.2 # TODO 这里设定阈值,如果预测值不可信则丢弃
right_rate = (preds[probs > refuse_threshold] == y[probs > refuse_threshold]).sum() / y.shape[0]
refuse_rate = (probs <= refuse_threshold).sum() / y.shape[0]
rights += right_rate
refuses += refuse_rate
print(f'\rtest epoch: [{epoch + 1}/{EPOCHS}], batch: [{batchs_num + 1}/{test_batchs}] right: {rights/(batchs_num + 1)*100:.2f}, refuse: {refuses/(batchs_num + 1)*100:.2f}.', end='')
print()
pyro.get_param_store().save('BayesianNet.save')
outputs:
train epoch: [1/10], batch: [469/469] loss: 34800.064756.
test epoch: [2/10], batch: [20/20] right: 4.12, refuse: 60.75.
train epoch: [2/10], batch: [469/469] loss: 25019.466096.
test epoch: [3/10], batch: [20/20] right: 4.19, refuse: 60.78.
train epoch: [3/10], batch: [469/469] loss: 16110.607585.
test epoch: [4/10], batch: [20/20] right: 4.00, refuse: 60.77.
train epoch: [4/10], batch: [469/469] loss: 9053.331933..
test epoch: [5/10], batch: [20/20] right: 4.39, refuse: 59.08.
test epoch: [6/10], batch: [20/20] right: 3.82, refuse: 59.95.
train epoch: [6/10], batch: [469/469] loss: 2758.916881.
test epoch: [7/10], batch: [20/20] right: 3.95, refuse: 59.95.
train epoch: [7/10], batch: [469/469] loss: 2096.815076.
test epoch: [8/10], batch: [20/20] right: 3.84, refuse: 60.05.
train epoch: [8/10], batch: [469/469] loss: 1847.571749.
test epoch: [9/10], batch: [20/20] right: 4.03, refuse: 59.99.
train epoch: [9/10], batch: [469/469] loss: 1647.313930.
test epoch: [10/10], batch: [20/20] right: 3.99, refuse: 60.70.
train epoch: [10/10], batch: [270/469] loss: 1550.818323.