Hello there.
I am trying to finish a project for my college course and I am getting a bit stuck when it comes to using pyro with a gpu. The project is to take Chest Xray images of covid and pneumonia patients and use a Bayesian Neural Network to classify them. I was able to train the model but now when I try to determine which images the BNN is uncertain about, I run out of GPU ram and I am not entirely sure why. I am very new to this subject so any help is appreciated. The error seems to be occurring in the BNN forward function. The Code I have is found below:
# Import relevant packages
import os
import random
import torch
import torch.nn.functional as nnf
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
from torch.optim import SGD
from torch.distributions import constraints
import torchvision as torchv
import torchvision.transforms as torchvt
from torchvision.datasets.mnist import MNIST
from torch import nn
from pyro.infer import SVI, TraceMeanField_ELBO
import pyro
from pyro import poutine
import pyro.optim as pyroopt
import pyro.distributions as dist
import pyro.contrib.bnn as bnn
import matplotlib.pyplot as plt
import seaborn as sns
from torch.distributions.utils import lazy_property
import math
from PIL import Image
%matplotlib inline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class ChestXRayDataset(torch.utils.data.Dataset):
def __init__(self, image_dirs, transform):
def get_images(class_name):
images = [x for x in os.listdir(image_dirs[class_name]) if x.lower().endswith('png')]
print(f'Found {len(images)} {class_name} examples')
return images
self.images = {}
self.class_names = ['normal', 'viral', 'covid']
for class_name in self.class_names:
self.images[class_name] = get_images(class_name)
self.image_dirs = image_dirs
self.transform = transform
def __len__(self):
return sum([len(self.images[class_name]) for class_name in self.class_names])
def __getitem__(self, index):
class_name = random.choice(self.class_names)
index = index % len(self.images[class_name])
image_name = self.images[class_name][index]
image_path = os.path.join(self.image_dirs[class_name], image_name)
image = Image.open(image_path).convert('RGB')
return self.transform(image), self.class_names.index(class_name)
train_transform = torchvt.Compose([
torchvt.Resize(size=(128, 128)),
torchvt.RandomHorizontalFlip(),
torchvt.Grayscale(num_output_channels=1),
torchvt.ToTensor(),
torchvt.Normalize((0.5, ), (0.5, ))
])
test_transform = torchvt.Compose([
torchvt.Resize(size=(128, 128)),
torchvt.Grayscale(num_output_channels=1),
torchvt.ToTensor(),
torchvt.Normalize((0.5, ), (0.5, ))
])
train_dirs = {
'normal': f'{root_dir}/normal',
'viral': f'{root_dir}/viral',
'covid': f'{root_dir}/covid'
}
train_dataset = ChestXRayDataset(train_dirs, train_transform)
test_dirs = {
'normal': f'{root_dir}/test/normal',
'viral': f'{root_dir}/test/viral',
'covid': f'{root_dir}/test/covid'
}
test_dataset = ChestXRayDataset(test_dirs, test_transform)
batch_size = 6
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
class BNN(nn.Module):
def __init__(self, n_hidden=1024, n_classes=10):
super(BNN, self).__init__()
self.n_hidden = n_hidden
self.n_classes = n_classes
def model(self, images, labels=None, kl_factor=1.0):
images = images.to(device).view(-1, 128*128)
n_images = images.size(0)
# Set-up parameters for the distribution of weights for each layer `a<n>`
a1_mean = torch.zeros(128*128, self.n_hidden, dtype=torch.float64, device=device)
a1_scale = torch.ones(128*128, self.n_hidden, dtype=torch.float64, device=device)
a1_dropout = torch.tensor(0.25, dtype=torch.float64, device=device)
a2_mean = torch.zeros(self.n_hidden + 1, self.n_classes, dtype=torch.float64, device=device)
a2_scale = torch.ones(self.n_hidden + 1, self.n_hidden, dtype=torch.float64, device=device)
a2_dropout = torch.tensor(1.0, dtype=torch.float64, device=device)
a3_mean = torch.zeros(self.n_hidden + 1, self.n_classes, dtype=torch.float64, device=device)
a3_scale = torch.ones(self.n_hidden + 1, self.n_hidden, dtype=torch.float64, device=device)
a3_dropout = torch.tensor(1.0, dtype=torch.float64, device=device)
a4_mean = torch.zeros(self.n_hidden + 1, self.n_classes, dtype=torch.float64, device=device)
a4_scale = torch.ones(self.n_hidden + 1, self.n_classes, dtype=torch.float64, device=device)
# Mark batched calculations to be conditionally independent given parameters using `plate`
with pyro.plate('data', size=n_images):
# Sample first hidden layer
h1 = pyro.sample('h1', bnn.HiddenLayer(images, a1_mean, a1_dropout * a1_scale,
non_linearity=nnf.leaky_relu,
KL_factor=kl_factor))
# Sample second hidden layer
h2 = pyro.sample('h2', bnn.HiddenLayer(h1, a2_mean, a2_dropout * a2_scale,
non_linearity=nnf.leaky_relu,
KL_factor=kl_factor))
# Sample third hidden layer
h3 = pyro.sample('h3', bnn.HiddenLayer(h2, a3_mean, a3_dropout * a3_scale,
non_linearity=nnf.leaky_relu,
KL_factor=kl_factor))
# Sample output logits
logits = pyro.sample('logits', bnn.HiddenLayer(h3, a4_mean, a4_scale,
non_linearity=lambda x: nnf.log_softmax(x, dim=-1),
KL_factor=kl_factor,
include_hidden_bias=False))
# One-hot encode labels
labels = nnf.one_hot(labels) if labels is not None else None
# Condition on observed labels, so it calculates the log-likehood loss when training using VI
return pyro.sample('label', dist.OneHotCategorical(logits=logits), obs=labels)
def guide(self, images, labels=None, kl_factor=1.0):
images = images.to(device).view(-1, 128*128)
n_images = images.size(0)
#print("guide: ", kl_factor)
# kl_factor = torch.tensor(kl_factor).to(device)
# Set-up parameters to be optimized to approximate the true posterior
# Mean parameters are randomly initialized to small values around 0, and scale parameters
# are initialized to be 0.1 to be closer to the expected posterior value which we assume is stronger than
# the prior scale of 1.
# Scale parameters must be positive, so we constraint them to be larger than some epsilon value (0.01).
# Variational dropout are initialized as in the prior model, and constrained to be between 0.1 and 1 (so dropout
# rate is between 0.1 and 0.5) as suggested in the local reparametrization paper
# print(0.01 * torch.randn(128*128, self.n_hidden, dtype=torch.float64, device=device))
a1_mean = pyro.param('a1_mean', 0.01 * torch.randn(128*128, self.n_hidden, device=device))
a1_scale = pyro.param('a1_scale', 0.1 * torch.ones(128*128, self.n_hidden, device=device),
constraint=constraints.greater_than(0.01))
a1_dropout = pyro.param('a1_dropout', torch.tensor(0.25, device=device),
constraint=constraints.interval(0.1, 1.0))
a2_mean = pyro.param('a2_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_hidden, device=device))
a2_scale = pyro.param('a2_scale', 0.1 * torch.ones(self.n_hidden + 1, self.n_hidden, device=device),
constraint=constraints.greater_than(0.01))
a2_dropout = pyro.param('a2_dropout', torch.tensor(1.0, device=device),
constraint=constraints.interval(0.1, 1.0))
a3_mean = pyro.param('a3_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_hidden, device=device))
a3_scale = pyro.param('a3_scale', 0.1 * torch.ones(self.n_hidden + 1, self.n_hidden, device=device),
constraint=constraints.greater_than(0.01))
a3_dropout = pyro.param('a3_dropout', torch.tensor(1.0, device=device),
constraint=constraints.interval(0.1, 1.0))
a4_mean = pyro.param('a4_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_classes, device=device))
a4_scale = pyro.param('a4_scale', 0.1 * torch.ones(self.n_hidden + 1, self.n_classes, device=device), constraint=constraints.greater_than(0.01))
# Sample latent values using the variational parameters that are set-up above.
# Notice how there is no conditioning on labels in the guide!
with pyro.plate('data', size=n_images):
h1 = pyro.sample('h1', bnn.HiddenLayer(images, a1_mean, a1_dropout * a1_scale,
non_linearity=nnf.leaky_relu,
KL_factor=kl_factor))
h2 = pyro.sample('h2', bnn.HiddenLayer(h1, a2_mean, a2_dropout * a2_scale,
non_linearity=nnf.leaky_relu,
KL_factor=kl_factor))
h3 = pyro.sample('h3', bnn.HiddenLayer(h2, a3_mean, a3_dropout * a3_scale,
non_linearity=nnf.leaky_relu,
KL_factor=kl_factor))
logits = pyro.sample('logits', bnn.HiddenLayer(h3, a4_mean, a4_scale,
non_linearity=lambda x: nnf.log_softmax(x, dim=-1),
KL_factor=kl_factor,
include_hidden_bias=False))
def infer_parameters(self, loader, lr=0.01, momentum=0.9,
num_epochs=30):
#optim = pyroopt.SGD({'lr': lr, 'momentum': momentum, 'nesterov': True})
optim = pyroopt.Adam({'lr': lr})
elbo = TraceMeanField_ELBO()
svi = SVI(self.model, self.guide, optim, elbo)
kl_factor = torch.tensor(loader.batch_size / len(loader.dataset)).to(device)
# print("infer: ", kl_factor)
for i in range(num_epochs):
total_loss = 0.0
total = 0.0
correct = 0.0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
loss = svi.step(images, labels, kl_factor=kl_factor)
#print(loss)
pred = self.forward(images, n_samples=1).mean(0)
total_loss += loss / len(loader.dataset)
total += labels.size(0)
correct += (pred.argmax(-1) == labels).sum().item()
param_store = pyro.get_param_store()
print(f"[Epoch {i + 1}] loss: {total_loss:.5E} accuracy: {correct / total * 100:.5f}")
def forward(self, images, n_samples=10):
res = []
for i in range(n_samples):
t = poutine.trace(self.guide).get_trace(images)
print(i, len(t))
res.append(t.nodes['logits']['value'])
return torch.stack(res, dim=0)
bayesnn = BNN(n_hidden=2048, n_classes=3).double().to(device)
pyro.get_param_store().load("last_attempt.dat")
uncertain_images = []
for image, _ in test_loader:
n_samples = 30
preds = bayesnn.forward(image.to(torch.device("cpu")).view(-1, 128*128), n_samples=n_samples).argmax(-1).squeeze()
pred_sum = [(i, c) for i, c in enumerate(preds.bincount(minlength=10).tolist()) if c > 0]
if len(pred_sum) > 1:
uncertain_images.append((image, "\n".join(f"{i}: {c / n_samples:.2f}" for i, c in pred_sum)))
if len(uncertain_images) >= 64:
break