Image segmentation using fully convolutional network

Hello! I am trying to implement a fully convolutional neural network to produce dense segmentation outputs (e.g. semantic segmentation).

So far, I have been trying to adapt this “MNIST classification with uncertainty” tutorial to perform segmentation by changing the fully connected layers to convolution layers and the output with kernel size of one.

The network looks like this with output dimension of two for background and foreground classes:

(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(32, 2, kernel_size=(1, 1), stride=(1, 1))

Basically, I just threshold MNIST images to get ground truth masks for keeping the task a simple two-class problem.

I used this answer for per-pixel image segmentation to make the batch samples and image pixels conditionally independent. If I run a step using SVI with a single batch of training data, I get the error shown below (also the code to reproduce).

I am not sure if this error has anything to do with the kernel size of one in the last layer or I am doing something wrong. I got stuck at this point and any help would be appreciated.

ValueError: Shape mismatch inside plate('height') at site class dim -3, 28 vs 128
              Trace Shapes:              
               Param Sites:              
              Sample Sites:              
 module$$$conv1.weight dist 32  1 3   3 |
                      value 32  1 3   3 |
   module$$$conv1.bias dist          32 |
                      value          32 |
 module$$$conv2.weight dist  2 32 1   1 |
                      value  2 32 1   1 |
   module$$$conv2.bias dist           2 |
                      value           2 |
                 batch dist             |
                      value         128 |
                 width dist             |
                      value          28 |
                height dist             |
                      value          28 |

The implementation is here with the pyro model and guide:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
from torch.utils.data.dataset import Dataset

import multiprocessing
import matplotlib.pyplot as plt

import pyro
from pyro.distributions import Normal, Categorical, OneHotCategorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.enable_validation(True)
pyro.distributions.enable_validation(True)

class FCNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, padding=1, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=out_channels, padding=0, kernel_size=1)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.softplus = nn.Softplus()

    def forward(self, x):
        output = F.relu(self.conv1(x))
        output = self.conv2(output)
        return output

    def model(self, x, y):
        conv1w_prior = Normal(loc=torch.zeros_like(self.conv1.weight),
                              scale=torch.ones_like(self.conv1.weight))
        conv1b_prior = Normal(loc=torch.zeros_like(self.conv1.bias),
                              scale=torch.ones_like(self.conv1.bias))
        conv2w_prior = Normal(loc=torch.zeros_like(self.conv2.weight),
                              scale=torch.ones_like(self.conv2.weight))
        conv2b_prior = Normal(loc=torch.zeros_like(self.conv2.bias),
                              scale=torch.ones_like(self.conv2.bias))
        priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
                  'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior}
        module = pyro.random_module('module', self, priors)
        model = module()
        lhat = self.log_softmax(model(x))
        batch_size, num_classes, width, height = x.shape
        num_classes = 2  # bg and fg
        # convert y (b, 1, w, h) to y_1_hot (b, w, h, c)
        y_1_hot = torch.eye(num_classes)[y.squeeze(1).to(torch.int64)]
        # reshape lhat (b, c, w, h) to (b, w, h, c)
        lhat = lhat.permute(0, 2, 3, 1)
        # BxWxHxC torch.Size([128, 28, 28, 2]) torch.Size([128, 28, 28, 2])
        with pyro.plate('batch', size=batch_size, dim=-4):
            with pyro.plate('width', size=width, dim=-2):
                with pyro.plate('height', size=height, dim=-3):
                    pyro.sample('class', OneHotCategorical(logits=lhat), obs=y_1_hot)

    def guide(self, x, y):
        # First layer weight distribution priors
        conv1w_mu = torch.randn_like(self.conv1.weight)
        conv1w_sigma = torch.randn_like(self.conv1.weight)
        conv1w_mu_param = pyro.param('conv1w_mu', conv1w_mu)
        conv1w_sigma_param = self.softplus(pyro.param('conv1w_sigma', conv1w_sigma))
        conv1w_prior = Normal(loc=conv1w_mu_param, scale=conv1w_sigma_param)

        # First layer bias distribution priors
        conv1b_mu = torch.randn_like(self.conv1.bias)
        conv1b_sigma = torch.randn_like(self.conv1.bias)
        conv1b_mu_param = pyro.param('conv1b_mu', conv1b_mu)
        conv1b_sigma_param = self.softplus(pyro.param('conv1b_sigma', conv1b_sigma))
        conv1b_prior = Normal(loc=conv1b_mu_param, scale=conv1b_sigma_param)

        # Second layer weight distribution priors
        conv2w_mu = torch.randn_like(self.conv2.weight)
        conv2w_sigma = torch.randn_like(self.conv2.weight)
        conv2w_mu_param = pyro.param('conv2w_mu', conv2w_mu)
        conv2w_sigma_param = self.softplus(pyro.param('conv2w_sigma', conv2w_sigma))
        conv2w_prior = Normal(loc=conv2w_mu_param, scale=conv2w_sigma_param)

        # Second layer bias distribution priors
        conv2b_mu = torch.randn_like(self.conv2.bias)
        conv2b_sigma = torch.randn_like(self.conv2.bias)
        conv2b_mu_param = pyro.param('conv2b_mu', conv2b_mu)
        conv2b_sigma_param = self.softplus(pyro.param('conv2b_sigma', conv2b_sigma))
        conv2b_prior = Normal(loc=conv2b_mu_param, scale=conv2b_sigma_param)

        priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
                  'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior}
        module = pyro.random_module('module', self, priors)
        return module()

class MNISTSeg(torch.utils.data.Dataset):
    def __init__(self):
        self.images = datasets.MNIST('mnist-data/', train=True, download=True,
                                     transform=transforms.Compose([transforms.ToTensor()]))
        self.len = len(self.images)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        img = self.images[index][0]
        gt = torch.zeros_like(img).to(torch.uint8)
        gt[img > 0] = 1
        return img, gt

train_loader = torch.utils.data.DataLoader(MNISTSeg(), batch_size=128, shuffle=True)

device = 'cpu'
net = FCNN(1, 2)
print(net)
for name, _ in net.named_parameters():
    print(name)

optim = Adam({"lr": 0.01})
svi = SVI(net.model, net.guide, optim, loss=Trace_ELBO())

# one batch only
data = next(iter(train_loader))
loss = svi.step(data[0].to(device), data[1].to(device))
print(loss)

@cnx I guess you can try:

conv1w_prior = Normal(..).to_event()

and similar for other priors/guides.