Beginners Help

Hello there.

I am trying to finish a project for my college course and I am getting a bit stuck when it comes to using Pyro with a GPU. The project is to take chest X-ray images of COVID and pneumonia patients and use a Bayesian neural network (BNN) to classify them. I was able to train the model, but now, when I try to determine which images the BNN is uncertain about, I run out of GPU RAM and I am not entirely sure why. I am very new to this subject, so any help is appreciated. The error seems to be occurring in the BNN forward function. The code I have is found below:

# Import relevant packages
import os
import random
import torch
import torch.nn.functional as nnf
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
from torch.optim import SGD 
from torch.distributions import constraints
import torchvision as torchv
import torchvision.transforms as torchvt
from torchvision.datasets.mnist import MNIST
from torch import nn
from pyro.infer import SVI, TraceMeanField_ELBO
import pyro
from pyro import poutine
import pyro.optim as pyroopt
import pyro.distributions as dist
import pyro.contrib.bnn as bnn
import matplotlib.pyplot as plt
import seaborn as sns
from torch.distributions.utils import lazy_property
import math
from PIL import Image
%matplotlib inline
# Run on the first CUDA GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class ChestXRayDataset(torch.utils.data.Dataset):
    """Image dataset over the 'normal', 'viral' and 'covid' folders.

    Args:
        image_dirs: Mapping from class name to the directory holding that
            class's images.
        transform: Callable applied to each loaded image before it is returned.
    """

    def __init__(self, image_dirs, transform):
        self.images = {}
        self.class_names = ['normal', 'viral', 'covid']

        # Index the files of every class up front; one summary line is printed per class.
        for label in self.class_names:
            self.images[label] = self._list_pngs(image_dirs[label], label)

        self.image_dirs = image_dirs
        self.transform = transform

    @staticmethod
    def _list_pngs(folder, class_name):
        """Return the entries of `folder` whose lower-cased name ends with 'png'."""
        found = []
        for entry in os.listdir(folder):
            if entry.lower().endswith('png'):
                found.append(entry)
        print(f'Found {len(found)} {class_name} examples')
        return found

    def __len__(self):
        """Total number of indexed images across all classes."""
        total = 0
        for label in self.class_names:
            total += len(self.images[label])
        return total

    def __getitem__(self, index):
        """Return `(transformed_image, class_index)` for a randomly chosen class.

        The class is drawn uniformly at random on every call; `index` only
        selects the file within that class (wrapped with modulo).
        """
        label = random.choice(self.class_names)
        file_names = self.images[label]
        file_name = file_names[index % len(file_names)]
        full_path = os.path.join(self.image_dirs[label], file_name)
        rgb_image = Image.open(full_path).convert('RGB')
        return self.transform(rgb_image), self.class_names.index(label)
def _build_transform(augment):
    """Assemble the image preprocessing pipeline.

    Every pipeline resizes to 128x128, converts to a single grayscale channel,
    converts to a tensor and normalizes with mean 0.5 and std 0.5. When
    `augment` is true a random horizontal flip is inserted after the resize.
    """
    steps = [torchvt.Resize(size=(128, 128))]
    if augment:
        steps.append(torchvt.RandomHorizontalFlip())
    steps.extend([
        torchvt.Grayscale(num_output_channels=1),
        torchvt.ToTensor(),
        torchvt.Normalize((0.5, ), (0.5, )),
    ])
    return torchvt.Compose(steps)


# Training uses the flip augmentation; evaluation does not.
train_transform = _build_transform(augment=True)

test_transform = _build_transform(augment=False)
# NOTE(review): `root_dir` is referenced here but not defined in this snippet;
# it has to be set before this point.

# Training images: one folder per class directly under root_dir.
train_dirs = dict(
    normal=f'{root_dir}/normal',
    viral=f'{root_dir}/viral',
    covid=f'{root_dir}/covid',
)
train_dataset = ChestXRayDataset(train_dirs, train_transform)

# Held-out images: same class folders, under root_dir/test.
test_dirs = dict(
    normal=f'{root_dir}/test/normal',
    viral=f'{root_dir}/test/viral',
    covid=f'{root_dir}/test/covid',
)
test_dataset = ChestXRayDataset(test_dirs, test_transform)

# Both loaders shuffle and share the same batch size.
batch_size = 6
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
class BNN(nn.Module):
    """Bayesian neural network with three hidden layers for 128x128 single-channel images.

    `model` defines the prior over the layer weights and the label likelihood,
    `guide` defines the variational approximation, `infer_parameters` fits the
    guide with SVI, and `forward` draws class log-probabilities from the guide.

    Args:
        n_hidden: Width of every hidden layer.
        n_classes: Number of output classes.
    """

    def __init__(self, n_hidden=1024, n_classes=10):
        super().__init__()
        self.n_hidden = n_hidden
        self.n_classes = n_classes

    def model(self, images, labels=None, kl_factor=1.0):
        """Prior over the layer weights plus the categorical label likelihood.

        Args:
            images: Batch of images; moved to `device` and flattened to
                `(-1, 128*128)`.
            labels: Optional integer class labels. When given, the 'label'
                site is conditioned on their one-hot encoding.
            kl_factor: Passed to every `HiddenLayer` as `KL_factor`.

        Returns:
            The value of the 'label' sample site.
        """
        images = images.to(device).view(-1, 128*128)
        n_images = images.size(0)
        # Set-up parameters for the distribution of weights for each layer `a<n>`
        a1_mean = torch.zeros(128*128, self.n_hidden, dtype=torch.float64, device=device)
        a1_scale = torch.ones(128*128, self.n_hidden, dtype=torch.float64, device=device)
        a1_dropout = torch.tensor(0.25, dtype=torch.float64, device=device)
        # Hidden-to-hidden layers: the mean must have the same
        # (n_hidden + 1, n_hidden) shape as the scale and as the matching
        # parameters in `guide`. It was previously (n_hidden + 1, n_classes),
        # which is the shape of the output layer only.
        a2_mean = torch.zeros(self.n_hidden + 1, self.n_hidden, dtype=torch.float64, device=device)
        a2_scale = torch.ones(self.n_hidden + 1, self.n_hidden, dtype=torch.float64, device=device)
        a2_dropout = torch.tensor(1.0, dtype=torch.float64, device=device)
        a3_mean = torch.zeros(self.n_hidden + 1, self.n_hidden, dtype=torch.float64, device=device)
        a3_scale = torch.ones(self.n_hidden + 1, self.n_hidden, dtype=torch.float64, device=device)
        a3_dropout = torch.tensor(1.0, dtype=torch.float64, device=device)
        a4_mean = torch.zeros(self.n_hidden + 1, self.n_classes, dtype=torch.float64, device=device)
        a4_scale = torch.ones(self.n_hidden + 1, self.n_classes, dtype=torch.float64, device=device)
        # Mark batched calculations to be conditionally independent given parameters using `plate`
        with pyro.plate('data', size=n_images):
            # Sample first hidden layer
            h1 = pyro.sample('h1', bnn.HiddenLayer(images, a1_mean, a1_dropout * a1_scale,
                                                   non_linearity=nnf.leaky_relu,
                                                   KL_factor=kl_factor))
            # Sample second hidden layer
            h2 = pyro.sample('h2', bnn.HiddenLayer(h1, a2_mean, a2_dropout * a2_scale,
                                                   non_linearity=nnf.leaky_relu,
                                                   KL_factor=kl_factor))
            # Sample third hidden layer
            h3 = pyro.sample('h3', bnn.HiddenLayer(h2, a3_mean, a3_dropout * a3_scale,
                                                   non_linearity=nnf.leaky_relu,
                                                   KL_factor=kl_factor))
            # Sample output logits
            logits = pyro.sample('logits', bnn.HiddenLayer(h3, a4_mean, a4_scale,
                                                           non_linearity=lambda x: nnf.log_softmax(x, dim=-1),
                                                           KL_factor=kl_factor,
                                                           include_hidden_bias=False))
            # One-hot encode labels. The width is pinned to n_classes: without
            # it, one_hot infers the width from the largest label present, so
            # a batch that happens to miss the last class would not line up
            # with the logits.
            if labels is not None:
                labels = nnf.one_hot(labels, num_classes=self.n_classes)
            # Condition on observed labels, so it calculates the log-likelihood loss when training using VI
            return pyro.sample('label', dist.OneHotCategorical(logits=logits), obs=labels)

    def guide(self, images, labels=None, kl_factor=1.0):
        """Variational approximation with learnable layer means, scales and dropout.

        Takes the same arguments as `model`; `labels` is accepted but not used.
        Registers the `a<n>_mean`, `a<n>_scale` and `a<n>_dropout` parameters
        in the Pyro param store and samples the 'h1', 'h2', 'h3' and 'logits'
        sites. Returns nothing.
        """
        images = images.to(device).view(-1, 128*128)
        n_images = images.size(0)
        # Set-up parameters to be optimized to approximate the true posterior
        # Mean parameters are randomly initialized to small values around 0, and scale parameters
        # are initialized to be 0.1 to be closer to the expected posterior value which we assume is stronger than
        # the prior scale of 1.
        # Scale parameters must be positive, so we constraint them to be larger than some epsilon value (0.01).
        # Variational dropout are initialized as in the prior model, and constrained to be between 0.1 and 1 (so dropout
        # rate is between 0.1 and 0.5) as suggested in the local reparametrization paper
        a1_mean = pyro.param('a1_mean', 0.01 * torch.randn(128*128, self.n_hidden, device=device))
        a1_scale = pyro.param('a1_scale', 0.1 * torch.ones(128*128, self.n_hidden, device=device),
                              constraint=constraints.greater_than(0.01))
        a1_dropout = pyro.param('a1_dropout', torch.tensor(0.25, device=device),
                                constraint=constraints.interval(0.1, 1.0))
        a2_mean = pyro.param('a2_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_hidden, device=device))
        a2_scale = pyro.param('a2_scale', 0.1 * torch.ones(self.n_hidden + 1, self.n_hidden, device=device),
                              constraint=constraints.greater_than(0.01))
        a2_dropout = pyro.param('a2_dropout', torch.tensor(1.0, device=device),
                                constraint=constraints.interval(0.1, 1.0))
        a3_mean = pyro.param('a3_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_hidden, device=device))
        a3_scale = pyro.param('a3_scale', 0.1 * torch.ones(self.n_hidden + 1, self.n_hidden, device=device),
                              constraint=constraints.greater_than(0.01))
        a3_dropout = pyro.param('a3_dropout', torch.tensor(1.0, device=device),
                                constraint=constraints.interval(0.1, 1.0))
        a4_mean = pyro.param('a4_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_classes, device=device))
        a4_scale = pyro.param('a4_scale', 0.1 * torch.ones(self.n_hidden + 1, self.n_classes, device=device),
                              constraint=constraints.greater_than(0.01))
        # Sample latent values using the variational parameters that are set-up above.
        # Notice how there is no conditioning on labels in the guide!
        with pyro.plate('data', size=n_images):
            h1 = pyro.sample('h1', bnn.HiddenLayer(images, a1_mean, a1_dropout * a1_scale,
                                                   non_linearity=nnf.leaky_relu,
                                                   KL_factor=kl_factor))
            h2 = pyro.sample('h2', bnn.HiddenLayer(h1, a2_mean, a2_dropout * a2_scale,
                                                   non_linearity=nnf.leaky_relu,
                                                   KL_factor=kl_factor))
            h3 = pyro.sample('h3', bnn.HiddenLayer(h2, a3_mean, a3_dropout * a3_scale,
                                                   non_linearity=nnf.leaky_relu,
                                                   KL_factor=kl_factor))
            logits = pyro.sample('logits', bnn.HiddenLayer(h3, a4_mean, a4_scale,
                                                           non_linearity=lambda x: nnf.log_softmax(x, dim=-1),
                                                           KL_factor=kl_factor,
                                                           include_hidden_bias=False))

    def infer_parameters(self, loader, lr=0.01, momentum=0.9,
                         num_epochs=30):
        """Fit the guide with SVI (Adam, TraceMeanField_ELBO) and print per-epoch stats.

        Args:
            loader: DataLoader yielding `(images, labels)` batches.
            lr: Learning rate for Adam.
            momentum: Unused by the Adam optimizer; kept so existing callers
                that pass it keep working.
            num_epochs: Number of passes over `loader`.
        """
        optim = pyroopt.Adam({'lr': lr})
        elbo = TraceMeanField_ELBO()
        svi = SVI(self.model, self.guide, optim, elbo)
        # Scale the KL term by the fraction of the dataset seen in one batch.
        kl_factor = torch.tensor(loader.batch_size / len(loader.dataset)).to(device)
        for i in range(num_epochs):
            total_loss = 0.0
            total = 0.0
            correct = 0.0
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                loss = svi.step(images, labels, kl_factor=kl_factor)
                # Training accuracy is estimated from a single draw of the guide.
                pred = self.forward(images, n_samples=1).mean(0)
                total_loss += loss / len(loader.dataset)
                total += labels.size(0)
                correct += (pred.argmax(-1) == labels).sum().item()
            print(f"[Epoch {i + 1}] loss: {total_loss:.5E} accuracy: {correct / total * 100:.5f}")

    def forward(self, images, n_samples=10):
        """Draw `n_samples` sets of class log-probabilities from the guide.

        Args:
            images: Batch of images accepted by `guide`.
            n_samples: Number of independent traces of the guide to run.

        Returns:
            The 'logits' site values of all draws stacked along a new leading
            dimension (expected shape `(n_samples, n_images, n_classes)`).
            The result carries no autograd history.
        """
        draws = []
        # Prediction needs no gradients. With autograd enabled, every draw
        # keeps a graph that references the weight-sized intermediates built
        # from the guide parameters, so memory grows with each sample held in
        # `draws`.
        with torch.no_grad():
            for _ in range(n_samples):
                trace = poutine.trace(self.guide).get_trace(images)
                draws.append(trace.nodes['logits']['value'])

        return torch.stack(draws, dim=0)
bayesnn = BNN(n_hidden=2048, n_classes=3).double().to(device)
pyro.get_param_store().load("last_attempt.dat")

# Collect test images on which the posterior draws disagree about the class.
n_samples = 30       # draws from the guide per batch
max_uncertain = 64   # stop once this many ambiguous images have been found
uncertain_images = []
for images, _ in test_loader:
    # Inference only: no_grad stops each draw from retaining an autograd
    # graph, which otherwise exhausts GPU memory over n_samples draws.
    with torch.no_grad():
        # One predicted class per (draw, image). The guide moves the batch to
        # `device` itself, so no explicit transfer is needed here.
        preds = bayesnn.forward(images.view(-1, 128*128), n_samples=n_samples).argmax(-1)
    # bincount only accepts 1-D input, so tally the votes image by image
    # instead of over the whole (n_samples, batch) tensor.
    for idx in range(preds.size(1)):
        votes = preds[:, idx].bincount(minlength=bayesnn.n_classes).tolist()
        pred_sum = [(i, c) for i, c in enumerate(votes) if c > 0]
        if len(pred_sum) > 1:
            # Keep a leading batch dimension of 1 on the stored image.
            uncertain_images.append((images[idx:idx + 1],
                                     "\n".join(f"{i}: {c / n_samples:.2f}" for i, c in pred_sum)))
        if len(uncertain_images) >= max_uncertain:
            break
    if len(uncertain_images) >= max_uncertain:
        break

Hi @Eastonboy99,
Here are some suggestions:

  1. set pyro.enable_validation(True) to ensure you don’t have shape errors. Sometimes shape errors can lead to OOM errors due to broadcasting e.g. torch.ones(1000000, 1) + torch.ones(1000000).
  2. Decorate your .forward() with @torch.no_grad().
  3. Try batching in forward. At the most extreme you could use two for loops, one over images and one over samples.
  4. Instead of accumulating torch.stack(res, dim=0), create a result of bincounts via .scatter_add_(). This would avoid creating an O(num_images * num_samples) intermediate object.

Good luck!

Awesome! Thank you so much!!

I’m a little confused as to where the .scatter_add_() would go. Would that be in the last for loop or in the return statement in the forward function?

You could add cheaply using something like this:

def forward(self, images, n_samples=10):
    """Tally, per image, how many posterior draws vote for each class.

    Args:
        images: Batch of images accepted by `self.guide`.
        n_samples: Number of traces of the guide to run.

    Returns:
        A CPU tensor of shape `(len(images), self.n_classes)` whose entry
        `[i, k]` is the number of draws in which class `k` had the highest
        value at the 'logits' site for image `i`.
    """
    counts = torch.zeros(len(images), self.n_classes)
    # No gradients are needed for prediction; without no_grad the accumulator
    # would chain the autograd graph of every draw.
    with torch.no_grad():
        for _ in range(n_samples):
            trace = poutine.trace(self.guide).get_trace(images)
            # The 'logits' site holds log-softmax activations, not one-hot
            # samples, so convert each draw to a vote before accumulating.
            # The votes are moved to the accumulator's device because the
            # guide may run on the GPU while `counts` lives on the CPU.
            votes = trace.nodes["logits"]["value"].argmax(-1).to(counts.device)
            counts += nnf.one_hot(votes, num_classes=self.n_classes)
    return counts

EDIT: updated to remove .scatter_add_().

Hi @Eastonboy99, sorry, I just realized you don’t need .scatter_add_() since you’re using OneHotCategorical; a simple += works fine. I’ve updated the above code snippet.