First, thank you so much for taking the time to answer my questions! That's really helpful!
Of course I can share the full code, I just didn't want to bore you with what I thought was not relevant here.
Below is the ResNet model. I did not implement an AutoGuide as I thought it required a full Pyro model. The ResNet is purely made of deterministic layers and I only wanted to retrain the FC layers as "probabilistic layers".
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResBlock(nn.Module):
    def __init__(self, num_layers, in_channels, out_channels, identity_downsample=None, stride=1):
        # assert lets execution continue if the condition evaluates to True;
        # otherwise it raises AssertionError with the given message.
        assert num_layers in [18, 34, 50, 101, 152], "should be a valid architecture"
        super(ResBlock, self).__init__()
        # For ResNets deeper than 34 layers, the number of channels in the 3rd convolution
        # is 4 times the number of channels of the second one.
        self.num_layers = num_layers
        if self.num_layers > 34:
            self.expansion = 4
        else:
            self.expansion = 1
        # ResNet50, 101, and 152 include an additional layer of 1x1 kernels
        # (see the forward function)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        if self.num_layers > 34:
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        else:
            # for ResNet18 and 34, connect the input directly to the (3x3) kernel (skip the first (1x1))
            self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        # ReLU activation
        self.relu = nn.ReLU()
        # conv layer applied to the identity mapping so it has the same shape as the main branch
        self.identity_downsample = identity_downsample

    def forward(self, x):
        identity = x
        if self.num_layers > 34:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        # We use this layer only if we need to change the shape of the identity
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        x += identity
        x = self.relu(x)
        return x
class ResNet(nn.Module):
    def __init__(self, num_layers, image_channels, num_classes, block):
        assert num_layers in [18, 34, 50, 101, 152], f'ResNet{num_layers}: Unknown architecture! Number of layers has to be 18, 34, 50, 101, or 152'
        super(ResNet, self).__init__()
        if num_layers < 50:
            self.expansion = 1
        else:
            self.expansion = 4
        if num_layers == 18:
            layers = [2, 2, 2, 2]
        elif num_layers == 34 or num_layers == 50:
            layers = [3, 4, 6, 3]
        elif num_layers == 101:
            layers = [3, 4, 23, 3]
        else:
            layers = [3, 8, 36, 3]
        # For ALL ResNets, the first layer is a 7x7 convolution with stride 2 and 64 output channels
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        # This first layer is followed by a maxpool with stride 2 and kernel size 3
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # ResNet layers
        self.layer1 = self.make_layers(num_layers, block, layers[0], intermediate_channels=64, stride=1)
        self.layer2 = self.make_layers(num_layers, block, layers[1], intermediate_channels=128, stride=2)
        self.layer3 = self.make_layers(num_layers, block, layers[2], intermediate_channels=256, stride=2)
        self.layer4 = self.make_layers(num_layers, block, layers[3], intermediate_channels=512, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * self.expansion, 1000)
        self.fc2 = nn.Linear(1000, 512)
        self.fc_out = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc_out(x)
        return x

    def make_layers(self, num_layers, block, num_residual_blocks, intermediate_channels, stride):
        layers = []
        # In the paper the identity downsample is a conv layer.
        # It is the layer that changes the number of channels: in the first ResNet layer, for
        # example, it changes 64 to 256. The downsample is ONLY applied to the first block of
        # each layer; after that it's just a normal identity mapping.
        identity_downsample = nn.Sequential(
            nn.Conv2d(self.in_channels, intermediate_channels * self.expansion, kernel_size=1, stride=stride),
            nn.BatchNorm2d(intermediate_channels * self.expansion))
        layers.append(block(num_layers, self.in_channels, intermediate_channels, identity_downsample, stride))
        # Then we need to update in_channels! At the end of the first block we have
        # intermediate_channels * expansion channels (e.g. 256 for the first ResNet50 layer).
        self.in_channels = intermediate_channels * self.expansion
        for i in range(num_residual_blocks - 1):
            layers.append(block(num_layers, self.in_channels, intermediate_channels))
        return nn.Sequential(*layers)
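For what it's worth, here is the quick sanity check I run on the deterministic backbone alone (the 224x224 input matches the RandomCrop size in my transform; the batch size of 4 is arbitrary):

import torch
from models.ResNet import ResNet, ResBlock

net = ResNet(num_layers=18, image_channels=3, num_classes=2, block=ResBlock)
dummy = torch.randn(4, 3, 224, 224)  # batch of 4 RGB images
print(net(dummy).shape)  # torch.Size([4, 2]) -> one logit per class

And below is the training script with the Pyro model and guide: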
# Libraries
import argparse
import torch
import glob
import numpy as np
import random
import torch.optim as optim
import torch.nn as nn
import albumentations as A
import pyro
from pyro.distributions import Normal, Categorical
from PIL import Image
from torch.utils.data import DataLoader
from models.ResNet import ResNet
from models.ResNet import ResBlock
from utils.loader import ImgLoader
class HybridNN():
    def __init__(self, data_path, model_path, save_path, num_epochs, batch_size):
        self.data_path = data_path
        self.model_path = model_path
        self.save_path = save_path
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        # Transform
        self.transform = self.initTransform()
        # Cuda
        self.is_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.is_cuda else "cpu")
        # Related to the model
        self.expansion = 1
        self.num_classes = 2
        self.NN = self.initModel()
        self.optimizer = self.initOptimizer()
        self.criterion = nn.CrossEntropyLoss()
        # For predictions
        self.num_samples = 10
        # Metrics
        self.METRICS_SIZE = 2
        self.LABELS_NDX = 0
        self.PREDICTED_NDX = 1

    def initModel(self):
        model = ResNet(num_layers=18, num_classes=2, image_channels=3, block=ResBlock)
        model.load_state_dict(torch.load(self.model_path), strict=False)
        # Freeze the layers
        for param in model.parameters():
            param.requires_grad = False
        # Reinitialize the FC layers so they are trainable
        model.fc1 = nn.Linear(512, 1000)
        model.fc2 = nn.Linear(1000, 512)
        model.fc_out = nn.Linear(512, self.num_classes)
        # Move everything (including the new FC layers) to the device
        model.to(self.device)
        return model

    def initOptimizer(self):
        # pyro.optim optimizers are wrappers around torch.optim optimizers
        return pyro.optim.Adam({"lr": 0.01})
    def model(self, x_data, y_data):
        """
        The model function defines how the output data is generated.
        """
        log_softmax = nn.LogSoftmax(dim=1)
        fc1w_prior = Normal(loc=torch.zeros_like(self.NN.fc1.weight), scale=torch.ones_like(self.NN.fc1.weight))
        fc1b_prior = Normal(loc=torch.zeros_like(self.NN.fc1.bias), scale=torch.ones_like(self.NN.fc1.bias))
        fc2w_prior = Normal(loc=torch.zeros_like(self.NN.fc2.weight), scale=torch.ones_like(self.NN.fc2.weight))
        fc2b_prior = Normal(loc=torch.zeros_like(self.NN.fc2.bias), scale=torch.ones_like(self.NN.fc2.bias))
        fc_outw_prior = Normal(loc=torch.zeros_like(self.NN.fc_out.weight), scale=torch.ones_like(self.NN.fc_out.weight))
        fc_outb_prior = Normal(loc=torch.zeros_like(self.NN.fc_out.bias), scale=torch.ones_like(self.NN.fc_out.bias))
        priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
                  'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior,
                  'fc_out.weight': fc_outw_prior, 'fc_out.bias': fc_outb_prior}
        # pyro.random_module() converts the parameters of our neural network (weights and biases)
        # into random variables whose initial (prior) probability distributions are given by priors
        lifted_module = pyro.random_module("module", self.NN, priors)
        # sample a regressor (which also samples w and b)
        lifted_reg_model = lifted_module()
        lhat = log_softmax(lifted_reg_model(x_data))
        # with pyro.plate("data", len(x_data)):
        pyro.sample("obs", Categorical(logits=lhat), obs=y_data)
    def guide(self, x_data, y_data):
        """
        The guide is the approximation of the posterior distribution: the variational distribution.
        http://pyro.ai/examples/intro_long.html#Background:-variational-inference
        See the video: https://youtu.be/DYRK0-_K2UU
        """
        softplus = torch.nn.Softplus()
        # First layer weight distribution parameters
        fc1w_mu = torch.randn_like(self.NN.fc1.weight)
        fc1w_sigma = torch.randn_like(self.NN.fc1.weight)
        fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
        fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
        fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
        # First layer bias distribution parameters
        fc1b_mu = torch.randn_like(self.NN.fc1.bias)
        fc1b_sigma = torch.randn_like(self.NN.fc1.bias)
        fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
        fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
        fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
        # Second layer weight distribution parameters
        fc2w_mu = torch.randn_like(self.NN.fc2.weight)
        fc2w_sigma = torch.randn_like(self.NN.fc2.weight)
        fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
        fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
        fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)
        # Second layer bias distribution parameters
        fc2b_mu = torch.randn_like(self.NN.fc2.bias)
        fc2b_sigma = torch.randn_like(self.NN.fc2.bias)
        fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
        fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
        fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)
        # Third layer weight distribution parameters
        fc_outw_mu = torch.randn_like(self.NN.fc_out.weight)
        fc_outw_sigma = torch.randn_like(self.NN.fc_out.weight)
        fc_outw_mu_param = pyro.param("fc_outw_mu", fc_outw_mu)
        fc_outw_sigma_param = softplus(pyro.param("fc_outw_sigma", fc_outw_sigma))
        fc_outw_prior = Normal(loc=fc_outw_mu_param, scale=fc_outw_sigma_param)
        # Third layer bias distribution parameters
        fc_outb_mu = torch.randn_like(self.NN.fc_out.bias)
        fc_outb_sigma = torch.randn_like(self.NN.fc_out.bias)
        fc_outb_mu_param = pyro.param("fc_outb_mu", fc_outb_mu)
        fc_outb_sigma_param = softplus(pyro.param("fc_outb_sigma", fc_outb_sigma))
        fc_outb_prior = Normal(loc=fc_outb_mu_param, scale=fc_outb_sigma_param)
        priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
                  'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior,
                  'fc_out.weight': fc_outw_prior, 'fc_out.bias': fc_outb_prior}
        lifted_module = pyro.random_module("module", self.NN, priors)
        return lifted_module()
    def initTransform(self):
        transform = A.Compose(
            [
                A.SmallestMaxSize(max_size=260),
                A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
                A.RandomCrop(height=224, width=224),
                A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        )
        return transform
    def predict(self, x):
        """
        After the optimization iterations, the distribution given by the guide() parameter values
        approximates the true posterior, so we can use it for predictions.
        For each prediction we sample a new set of weights and biases num_samples times,
        which effectively means we sample num_samples (here 10) neural networks to make one prediction.
        To make a prediction, we average the final layer outputs of the 10 sampled nets for
        the given input and take the class with the highest mean activation as the predicted class.
        """
        sampled_models = [self.guide(None, None) for _ in range(self.num_samples)]
        yhats = [model(x).data for model in sampled_models]
        mean = torch.mean(torch.stack(yhats), 0)
        return np.argmax(mean.numpy(), axis=1)
    def trainingLoop(self):
        # Specify the loss to optimize
        elbo = pyro.infer.Trace_ELBO()
        svi = pyro.infer.SVI(self.model, self.guide, self.optimizer, loss=elbo)
        # Prepare the dataset
        list_images = glob.glob(self.data_path + "/*.jpg")
        # Make a training and a testing set
        train_images = random.sample(list_images, int(len(list_images) * 0.8))
        test_images = [item for item in list_images if item not in train_images]
        trainset = ImgLoader(train_images, transform=self.transform)
        testset = ImgLoader(test_images, transform=self.transform)
        trainLoader = DataLoader(trainset,
                                 batch_size=self.batch_size,
                                 shuffle=True,
                                 pin_memory=self.is_cuda)
        # Train the network
        for epoch in range(self.num_epochs):  # loop over the dataset multiple times
            training_loss = 0.0
            nb_train_batch = 0
            metrics_trn = torch.zeros(self.METRICS_SIZE, len(trainLoader.dataset), device=self.device)
            # metrics_val = torch.zeros(self.METRICS_SIZE, len(testLoader.dataset), device=self.device)
            for batch_ndx, tup_ndx in enumerate(trainLoader, 0):
                start_ndx = batch_ndx * self.batch_size
                end_ndx = start_ndx + self.batch_size
                # get the inputs; data is a list of [imgs, labels]
                imgs, labels = tup_ndx
                imgs = imgs.permute(0, 3, 1, 2)
                # Move imgs and labels to the GPU if possible
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                # forward + backward + optimize
                loss = svi.step(imgs, labels)
                # Compute the metrics
                outputs = self.NN(imgs)
                _, predicted = torch.max(outputs, dim=1)
                metrics_trn[self.LABELS_NDX, start_ndx:end_ndx] = labels
                metrics_trn[self.PREDICTED_NDX, start_ndx:end_ndx] = predicted
                # Get the accuracy
                correct = torch.eq(metrics_trn[self.LABELS_NDX], metrics_trn[self.PREDICTED_NDX]).sum()
                accuracy = correct / len(metrics_trn[self.LABELS_NDX])
                training_loss += loss
                nb_train_batch += 1
            print("Epoch: {}, Training loss: {}, Training accuracy: {}".format(epoch, training_loss / nb_train_batch, accuracy))
        print('Finished Training')
        # Save the learned variational parameters (they live in Pyro's param store,
        # since the frozen self.NN itself is not what SVI trains)
        pyro.get_param_store().save(self.save_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path",
                        help="Path to the folder containing the data (train and test sets will be sampled from it)",
                        required=True,
                        type=str
                        )
    parser.add_argument("--model_path",
                        help="Path to the pretrained model weights",
                        required=True,
                        type=str
                        )
    parser.add_argument("--save_path",
                        help="Path where the trained parameters will be saved",
                        required=True,
                        type=str
                        )
    parser.add_argument("--num_epochs",
                        help="Number of epochs to train the model for",
                        required=True,
                        type=int
                        )
    parser.add_argument("--batch_size",
                        help="Batch size used by the DataLoader",
                        required=True,
                        type=int
                        )
    cli_args = parser.parse_args()
    HybridNN(cli_args.data_path, cli_args.model_path, cli_args.save_path, cli_args.num_epochs, cli_args.batch_size).trainingLoop()
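For reference, this is roughly how I launch the script (the file name bnn_train.py and the paths are just what I use locally):

python bnn_train.py --data_path ./data --model_path ./weights/resnet18.pth --save_path ./weights/bnn_params.save --num_epochs 10 --batch_size 16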
I also computed the batch_shape at each fully connected layer and I get:
Guide shape
torch.Size([1000, 512]) # fc1 weights
torch.Size([1000])      # fc1 biases
torch.Size([512, 1000]) # fc2 weights
torch.Size([512])       # fc2 biases
torch.Size([2, 512])    # fc_out weights
torch.Size([2])         # fc_out biases
Model shape
torch.Size([1000, 512])
torch.Size([1000])
torch.Size([512, 1000])
torch.Size([512])
torch.Size([2, 512])
torch.Size([2])
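In case it helps, this is a minimal sketch of how I printed those shapes (hnn stands for an instance of HybridNN; the dummy batch is a placeholder):

import torch
import pyro.poutine as poutine

# The guide can be traced without data: the variational parameters only depend on the layer shapes
guide_trace = poutine.trace(hnn.guide).get_trace(None, None)
for name, site in guide_trace.nodes.items():
    if site["type"] == "sample":
        print(name, site["fn"].batch_shape)

# The model needs a real batch since it runs the forward pass
imgs = torch.randn(2, 3, 224, 224)
labels = torch.zeros(2, dtype=torch.long)
model_trace = poutine.trace(hnn.model).get_trace(imgs, labels)
for name, site in model_trace.nodes.items():
    if site["type"] == "sample":
        print(name, site["fn"].batch_shape)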