Hello! I am trying to implement a fully convolutional neural network to produce dense segmentation outputs (e.g. semantic segmentation).
So far, I have been trying to adapt this “MNIST classification with uncertainty” tutorial to perform segmentation by replacing the fully connected layers with convolution layers, ending in an output convolution with a kernel size of one.
The network looks like this, with two output channels for the background and foreground classes:
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(32, 2, kernel_size=(1, 1), stride=(1, 1))
Basically, I just threshold the MNIST images to get ground-truth masks, to keep the task a simple two-class problem.
I used this answer on per-pixel image segmentation to make the batch samples and image pixels conditionally independent via nested plates. If I run a single SVI step on one batch of training data, I get the error shown below (the code to reproduce it is also included).
I am not sure whether this error has anything to do with the kernel size of one in the last layer, or whether I am doing something wrong elsewhere. I got stuck at this point, and any help would be appreciated.
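For reference, here is my (possibly mistaken) understanding of the tensor shapes involved, checked on plain tensors outside of Pyro (the sizes mirror my actual batch):

import torch

x = torch.randn(128, 1, 28, 28)             # batch of MNIST images, (B, C, H, W)
logits = torch.randn(128, 2, 28, 28)        # network output, (B, 2, H, W)
logits = logits.permute(0, 2, 3, 1)         # move classes last, (B, H, W, 2)
y = torch.randint(0, 2, (128, 1, 28, 28))   # binary masks, (B, 1, H, W)
y_1_hot = torch.eye(2)[y.squeeze(1)]        # one-hot targets, (B, H, W, 2)
print(logits.shape, y_1_hot.shape)          # both torch.Size([128, 28, 28, 2])

So the tensors going into pyro.sample look right to me; the mismatch seems to come from the plate dims. The full error: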
ValueError: Shape mismatch inside plate('height') at site class dim -3, 28 vs 128
 Trace Shapes:
  Param Sites:
 Sample Sites:
module$$$conv1.weight dist 32 1 3 3 |
                     value 32 1 3 3 |
  module$$$conv1.bias dist      32  |
                     value      32  |
module$$$conv2.weight dist 2 32 1 1 |
                     value 2 32 1 1 |
  module$$$conv2.bias dist       2  |
                     value       2  |
                batch dist          |
                     value     128  |
                width dist          |
                     value      28  |
               height dist          |
                     value      28  |
The full implementation, with the Pyro model and guide, follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from torch.utils.data.dataset import Dataset
import multiprocessing
import matplotlib.pyplot as plt
import pyro
from pyro.distributions import Normal, Categorical, OneHotCategorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
pyro.enable_validation(True)
pyro.distributions.enable_validation(True)
class FCNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, padding=1, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=out_channels, padding=0, kernel_size=1)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.softplus = nn.Softplus()

    def forward(self, x):
        output = F.relu(self.conv1(x))
        output = self.conv2(output)
        return output
    def model(self, x, y):
        conv1w_prior = Normal(loc=torch.zeros_like(self.conv1.weight),
                              scale=torch.ones_like(self.conv1.weight))
        conv1b_prior = Normal(loc=torch.zeros_like(self.conv1.bias),
                              scale=torch.ones_like(self.conv1.bias))
        conv2w_prior = Normal(loc=torch.zeros_like(self.conv2.weight),
                              scale=torch.ones_like(self.conv2.weight))
        conv2b_prior = Normal(loc=torch.zeros_like(self.conv2.bias),
                              scale=torch.ones_like(self.conv2.bias))
        priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
                  'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior}
        module = pyro.random_module('module', self, priors)
        model = module()
        lhat = self.log_softmax(model(x))
        batch_size, _, height, width = x.shape  # x is (B, C, H, W)
        num_classes = 2  # background and foreground
        # convert y (B, 1, H, W) to y_1_hot (B, H, W, C)
        y_1_hot = torch.eye(num_classes)[y.squeeze(1).to(torch.int64)]
        # reshape lhat (B, C, H, W) to (B, H, W, C)
        lhat = lhat.permute(0, 2, 3, 1)
        # both lhat and y_1_hot are now torch.Size([128, 28, 28, 2])
        # nested plates, following the linked answer, to mark batch samples
        # and individual pixels as conditionally independent
        with pyro.plate('batch', size=batch_size, dim=-4):
            with pyro.plate('width', size=width, dim=-2):
                with pyro.plate('height', size=height, dim=-3):
                    pyro.sample('class', OneHotCategorical(logits=lhat), obs=y_1_hot)
    def guide(self, x, y):
        # Variational distribution for the first-layer weights
        conv1w_mu = torch.randn_like(self.conv1.weight)
        conv1w_sigma = torch.randn_like(self.conv1.weight)
        conv1w_mu_param = pyro.param('conv1w_mu', conv1w_mu)
        conv1w_sigma_param = self.softplus(pyro.param('conv1w_sigma', conv1w_sigma))
        conv1w_prior = Normal(loc=conv1w_mu_param, scale=conv1w_sigma_param)
        # Variational distribution for the first-layer bias
        conv1b_mu = torch.randn_like(self.conv1.bias)
        conv1b_sigma = torch.randn_like(self.conv1.bias)
        conv1b_mu_param = pyro.param('conv1b_mu', conv1b_mu)
        conv1b_sigma_param = self.softplus(pyro.param('conv1b_sigma', conv1b_sigma))
        conv1b_prior = Normal(loc=conv1b_mu_param, scale=conv1b_sigma_param)
        # Variational distribution for the second-layer weights
        conv2w_mu = torch.randn_like(self.conv2.weight)
        conv2w_sigma = torch.randn_like(self.conv2.weight)
        conv2w_mu_param = pyro.param('conv2w_mu', conv2w_mu)
        conv2w_sigma_param = self.softplus(pyro.param('conv2w_sigma', conv2w_sigma))
        conv2w_prior = Normal(loc=conv2w_mu_param, scale=conv2w_sigma_param)
        # Variational distribution for the second-layer bias
        conv2b_mu = torch.randn_like(self.conv2.bias)
        conv2b_sigma = torch.randn_like(self.conv2.bias)
        conv2b_mu_param = pyro.param('conv2b_mu', conv2b_mu)
        conv2b_sigma_param = self.softplus(pyro.param('conv2b_sigma', conv2b_sigma))
        conv2b_prior = Normal(loc=conv2b_mu_param, scale=conv2b_sigma_param)
        priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
                  'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior}
        module = pyro.random_module('module', self, priors)
        return module()
class MNISTSeg(torch.utils.data.Dataset):
    def __init__(self):
        self.images = datasets.MNIST('mnist-data/', train=True, download=True,
                                     transform=transforms.Compose([transforms.ToTensor()]))
        self.len = len(self.images)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        img = self.images[index][0]
        # threshold the image to get a binary ground-truth mask
        gt = torch.zeros_like(img).to(torch.uint8)
        gt[img > 0] = 1
        return img, gt
train_loader = torch.utils.data.DataLoader(MNISTSeg(), batch_size=128, shuffle=True)
device = 'cpu'
net = FCNN(1, 2)
print(net)
for name, _ in net.named_parameters():
    print(name)
optim = Adam({"lr": 0.01})
svi = SVI(net.model, net.guide, optim, loss=Trace_ELBO())
# one batch only
data = next(iter(train_loader))
loss = svi.step(data[0].to(device), data[1].to(device))
print(loss)
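For what it is worth, checking the distribution shapes directly matches my reading of the tensor shapes tutorial, namely that OneHotCategorical treats the trailing class dimension as the event dimension (which makes me suspect my plate dim arguments, but I do not see how to fix them):

d = OneHotCategorical(logits=torch.randn(128, 28, 28, 2))
print(d.batch_shape)  # torch.Size([128, 28, 28])
print(d.event_shape)  # torch.Size([2])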