Bayesian ResNet: Invalid log_prob shape

Hi,

I am aware that the topic has been mentioned more than once and I am certain that my problem is a trivial one. Somehow I cannot get my head around it and I would extremely glad if I could get a some help and some more detailed explanations.

I am currently trying to convert the deterministic fully connected layers of a pre-trained ResNet model to probabilistic layers. I have followed some of the tutorial found online but I am stuck with understanding dimensionality with pyro as I get the error:

ValueError: at site "module$$$fc1.weight", invalid log_prob shape
  Expected [], actual [1000, 512]

Here is the code for initializing the ResNet and freeze the convolutional layers.

        
    def initModel(self):
        model = ResNet(num_layers=18, num_classes=2, image_channels=3, block=ResBlock)
        model.load_state_dict(torch.load(self.model_path), strict=False)
        model.to(self.device)

        # Freeze the layers
        for param in model.parameters():
            param.requires_grad = False

        # Reinitialize FC layers so they are trainable
        model.fc1 = nn.Linear(512, 1000)
        model.fc2 = nn.Linear(1000, 512)
        model.fc_out = nn.Linear(512, self.num_classes)

        return model

The model for the

    def model(self, x_data, y_data):


        log_softmax = nn.LogSoftmax(dim=1)

        fc1w_prior = Normal(loc=torch.zeros_like(self.NN.fc1.weight), scale=torch.ones_like(self.NN.fc1.weight))
        fc1b_prior = Normal(loc=torch.zeros_like(self.NN.fc1.bias), scale=torch.ones_like(self.NN.fc1.bias))
        
        fc2w_prior = Normal(loc=torch.zeros_like(self.NN.fc2.weight), scale=torch.ones_like(self.NN.fc2.weight))
        fc2b_prior = Normal(loc=torch.zeros_like(self.NN.fc2.bias), scale=torch.ones_like(self.NN.fc2.bias))

        fc_outw_prior = Normal(loc=torch.zeros_like(self.NN.fc_out.weight), scale=torch.ones_like(self.NN.fc_out.weight))
        fc_outb_prior = Normal(loc=torch.zeros_like(self.NN.fc_out.bias), scale=torch.ones_like(self.NN.fc_out.bias))
        
        priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,  
                    'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior,
                    'fc_out.weight': fc_outw_prior, 'fc_out.bias': fc_outb_prior}

        # the function pyro.random_module() converts parameters of our neural network (weights and biases) 
        # into random variables that have the initial (prior) probability distribution given by fc_priors
        lifted_module = pyro.random_module("module", self.NN, priors)

        # sample a regressor (which also samples w and b)
        lifted_reg_model = lifted_module()
        
        lhat = log_softmax(lifted_reg_model(x_data))
        
        pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

And the guide:

def guide(self, x_data, y_data):
        """"
        The guide is the approximation of the posterior distribution: the variational distribution
        http://pyro.ai/examples/intro_long.html#Background:-variational-inference
        See the video: https://youtu.be/DYRK0-_K2UU
        """
        
        softplus = torch.nn.Softplus()

        # First layer weight distribution priors
        fc1w_mu = torch.randn_like(self.NN.fc1.weight)
        fc1w_sigma = torch.randn_like(self.NN.fc1.weight)
        fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
        fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
        fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)

        # First layer bias distribution priors
        fc1b_mu = torch.randn_like(self.NN.fc1.bias)
        fc1b_sigma = torch.randn_like(self.NN.fc1.bias)
        fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
        fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
        fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

        # Second layer weight distribution priors
        fc2w_mu = torch.randn_like(self.NN.fc2.weight)
        fc2w_sigma = torch.randn_like(self.NN.fc2.weight)
        fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
        fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
        fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)

        # Second layer bias distribution priors
        fc2b_mu = torch.randn_like(self.NN.fc2.bias)
        fc2b_sigma = torch.randn_like(self.NN.fc2.bias)
        fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
        fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
        fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)

        # Third layer weight distribution priors
        fc_outw_mu = torch.randn_like(self.NN.fc_out.weight)
        fc_outw_sigma = torch.randn_like(self.NN.fc_out.weight)
        fc_outw_mu_param = pyro.param("fc_outw_mu", fc_outw_mu)
        fc_outw_sigma_param = softplus(pyro.param("fc_outw_sigma", fc_outw_sigma))
        fc_outw_prior = Normal(loc=fc_outw_mu_param, scale=fc_outw_sigma_param)

        # Third layer bias distribution priors
        fc_outb_mu = torch.randn_like(self.NN.fc_out.bias)
        fc_outb_sigma = torch.randn_like(self.NN.fc_out.bias)
        fc_outb_mu_param = pyro.param("fc_outb_mu", fc_outb_mu)
        fc_outb_sigma_param = softplus(pyro.param("fc_outb_sigma", fc_outb_sigma))
        fc_outb_prior = Normal(loc=fc_outb_mu_param, scale=fc_outb_sigma_param)
        
        priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 
                    'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior,
                    'fc_out.weight': fc_outw_prior, 'fc_out.bias': fc_outb_prior}

        lifted_module = pyro.random_module("module", self.NN, priors)
        
        return lifted_module()

My understanding is that in “model” the priors for biases and weights should have a dimension of output size of fully connected layer * batch size. For instance fc1wprior:

fc1w_prior = Normal(loc=torch.zeros_like(self.NN.fc1.weight), scale=torch.ones_like(self.NN.fc1.weight))

should have a dimension of 1000 * batch_size. Nevertheless, the shape of the tensor is currently 512 (input) * 1000 (output), is that correct?

Then would the correct code be something like:

fc1w_prior = Normal(loc=torch.zeros(1000), scale=torch.ones(1000)).expand(batch_size) ?

Also, if I understand correctly, in my case, pyro.sample should be of dimension 2 (number of output neurons) x batch_size, is it correct?

In some posts addressing the same error I saw the mention of ìndependence(1) or (2) and to_event(). However, I fail to see how this is relevant in my case. As I understand to_event() is used to assume dependence between dimensions whereas independence() is used to declare independence between dimensions.

I hope I clearly explained my problems and the questions resulting from it!

Looking forward to your help!

I think the prior should be independent of the batch_shape and return a tensor with the same shape as the weights in the deterministic network. See this example for a bayesian version of an MLP with numpyro.

That’s what I understood when I first specified the model and guide but then I get the error below.

ValueError: at site "module$$$fc1.weight", invalid log_prob shape
  Expected [], actual [1000, 512]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Reading Tensor shapes in Pyro — Pyro Tutorials 1.8.0 documentation I get that the log_prob tensor should have a shape of batch_shape. So why is the expected batch_shape [] instead of [1000, 512]?

I’m not sure if the interpretation is that the log_prob tensor should have the shape batch_shape but either that this is the behaviour. F.e. Normal(0,1).log_prob(torch.ones(512)).shape==(512,).

Could you share your full code?
Have you tried the AutoGuide? If your model works in combination with the AutoGuide then you know that your model specification is correct.

First thank you so much for taking the time to answer my questions! That’s really helpful!

Of course I can share the full code, I just didn’t want to bore you with what I thought was not relevant her :slight_smile:

Below is the ResNet model. I did not implement the AutoGuide as I thought it required a full Pyro model. The ResNet is purely made of deterministic layers and I only wanted to retrain the FC layers as “probabilistic layers”.

import torch 
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):

    def __init__(self, num_layers, in_channels, out_channels, identity_downsample=None, stride=1):
        # the assert statement is used to continue the execute if the given condition evaluates to True. 
        # If the assert condition evaluates to False, then it raises the AssertionError exception with the specified error message.
        assert num_layers in [18, 34, 50, 101, 152], "should be a a valid architecture"
        super(ResBlock, self).__init__()
        # In the ResNet > 34, the number of channels in the 3rd convolution is 4 times the number
        # of channels of the second one.
        self.num_layers = num_layers
        if self.num_layers > 34:
            self.expansion = 4
        else:
            self.expansion = 1

        # ResNet50, 101, and 152 include additional layer of 1x1 kernels
        # See the forward function
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)

        if self.num_layers > 34:
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        else:
            # for ResNet18 and 34, connect input directly to (3x3) kernel (skip first (1x1))
            self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)

        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        # RELU function
        self.relu = nn.ReLU()

        # conv layer that we will do to the identity mapping so it has the same shape as the other layers
        self.identity_downsample = identity_downsample

    def forward(self, x):
        identity = x
        if self.num_layers > 34:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)

        # We use the layer if we need to change the shape
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        x += identity
        x = self.relu(x)
        return x

class ResNet(nn.Module):
    def __init__(self, num_layers, image_channels, num_classes, block):
        assert num_layers in [18, 34, 50, 101, 152], f'ResNet{num_layers}: Unknown architecture! Number of layers has to be 18, 34, 50, 101, or 152 '
        super(ResNet, self).__init__()
        if num_layers < 50:
            self.expansion = 1
        else:
            self.expansion = 4
        if num_layers == 18:
            layers = [2, 2, 2, 2]
        elif num_layers == 34 or num_layers == 50:
            layers = [3, 4, 6, 3]
        elif num_layers == 101:
            layers = [3, 4, 23, 3]
        else:
            layers = [3, 8, 36, 3]

        # For ALL ResNet, the first layer is a convolution of stride 7 with 64 output channels
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

        # This first layer is followed by a maxpool of stride 2 and kernel of 3
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNetLayers
        self.layer1 = self.make_layers(num_layers, block, layers[0], intermediate_channels=64, stride=1)
        self.layer2 = self.make_layers(num_layers, block, layers[1], intermediate_channels=128, stride=2)
        self.layer3 = self.make_layers(num_layers, block, layers[2], intermediate_channels=256, stride=2)
        self.layer4 = self.make_layers(num_layers, block, layers[3], intermediate_channels=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * self.expansion, 1000)
        self.fc2 = nn.Linear(1000, 512)
        self.fc_out = nn.Linear(512, num_classes)

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc_out(x)
        return x

    def make_layers(self, num_layers, block, num_residual_blocks, intermediate_channels, stride):

        layers = []

        # In the paper the identity downsample is a conv layer
        identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, intermediate_channels*self.expansion, kernel_size=1, stride=stride),
                                            nn.BatchNorm2d(intermediate_channels*self.expansion))

        # This is the layer that changes the number of channels
        # If we look at the first ResNet layer, it's going to change 64 to 256
        # The downsample is ONLY for the first block, then it's just a normal identity 
        # mapping
        layers.append(block(num_layers, self.in_channels, intermediate_channels, identity_downsample, stride))

        # Then we need to update in_channels!
        self.in_channels = intermediate_channels * self.expansion # 256
        for i in range(num_residual_blocks - 1):
            layers.append(block(num_layers, self.in_channels, intermediate_channels)) # At the end of the first block

            # we will have 256 channels (in_channels here) and out_channels will be 64. 
        return nn.Sequential(*layers)
# Libraries
import argparse
import torch
import glob
import numpy as np
import random
import torch.optim as optim
import torch.nn as nn
import albumentations as A

import pyro
from pyro.distributions import Normal, Categorical

from PIL import Image
from torch.utils.data import DataLoader
from models.ResNet import ResNet
from models.ResNet import ResBlock

from utils.loader import ImgLoader

class HybridNN():

    def __init__(self, data_path, model_path, save_path, num_epochs, batch_size):

        self.data_path = data_path
        self.model_path = model_path
        self.save_path = save_path
        self.num_epochs = num_epochs
        self.batch_size = batch_size

        # Transform
        self.transform = self.initTransform()

        # Cuda
        self.is_cuda = torch.cuda.is_available()
        self.device = torch.device("cpu" if self.is_cuda else "cpu") # Place tensors on GPU later

        # Related to model
        self.expansion = 1
        self.num_classes = 2
        self.NN = self.initModel()
        self.optimizer = self.initOptimizer()
        self.criterion = nn.CrossEntropyLoss()

        # For predictions
        self.num_samples = 10

        # Metrics
        self.METRICS_SIZE = 2
        self.LABELS_NDX = 0
        self.PREDICTED_NDX = 1
        
    def initModel(self):
        model = ResNet(num_layers=18, num_classes=2, image_channels=3, block=ResBlock)
        model.load_state_dict(torch.load(self.model_path), strict=False)
        model.to(self.device)

        # Freeze the layers
        for param in model.parameters():
            param.requires_grad = False

        # Reinitialize FC layers so they are trainable
        model.fc1 = nn.Linear(512, 1000)
        model.fc2 = nn.Linear(1000, 512)
        model.fc_out = nn.Linear(512, self.num_classes)

        return model

    def initOptimizer(self):

        # pyro.optim functions are wrappers around torch.optim functions
        return pyro.optim.Adam({"lr": 0.01})

    def model(self, x_data, y_data):
        """
        Model function defines how the output data is generated
        """

        log_softmax = nn.LogSoftmax(dim=1)

        fc1w_prior = Normal(loc=torch.zeros_like(self.NN.fc1.weight), scale=torch.ones_like(self.NN.fc1.weight))
        fc1b_prior = Normal(loc=torch.zeros_like(self.NN.fc1.bias), scale=torch.ones_like(self.NN.fc1.bias))
        
        fc2w_prior = Normal(loc=torch.zeros_like(self.NN.fc2.weight), scale=torch.ones_like(self.NN.fc2.weight))
        fc2b_prior = Normal(loc=torch.zeros_like(self.NN.fc2.bias), scale=torch.ones_like(self.NN.fc2.bias))

        fc_outw_prior = Normal(loc=torch.zeros_like(self.NN.fc_out.weight), scale=torch.ones_like(self.NN.fc_out.weight))
        fc_outb_prior = Normal(loc=torch.zeros_like(self.NN.fc_out.bias), scale=torch.ones_like(self.NN.fc_out.bias))
        
        priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,  
                    'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior,
                    'fc_out.weight': fc_outw_prior, 'fc_out.bias': fc_outb_prior}

        # the function pyro.random_module() converts parameters of our neural network (weights and biases) 
        # into random variables that have the initial (prior) probability distribution given by fc_priors
        lifted_module = pyro.random_module("module", self.NN, priors)

        # sample a regressor (which also samples w and b)
        lifted_reg_model = lifted_module()
        lhat = log_softmax(lifted_reg_model(x_data)) 

        #with pyro.plate("data", len(x_data)):
        pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

    def guide(self, x_data, y_data):
        """"
        The guide is the approximation of the posterior distribution: the variational distribution
        http://pyro.ai/examples/intro_long.html#Background:-variational-inference
        See the video: https://youtu.be/DYRK0-_K2UU
        """
        
        softplus = torch.nn.Softplus()

        # First layer weight distribution priors
        fc1w_mu = torch.randn_like(self.NN.fc1.weight)
        fc1w_sigma = torch.randn_like(self.NN.fc1.weight)
        fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
        fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
        fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)

        # First layer bias distribution priors
        fc1b_mu = torch.randn_like(self.NN.fc1.bias)
        fc1b_sigma = torch.randn_like(self.NN.fc1.bias)
        fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
        fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
        fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

        # Second layer weight distribution priors
        fc2w_mu = torch.randn_like(self.NN.fc2.weight)
        fc2w_sigma = torch.randn_like(self.NN.fc2.weight)
        fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
        fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
        fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)

        # Second layer bias distribution priors
        fc2b_mu = torch.randn_like(self.NN.fc2.bias)
        fc2b_sigma = torch.randn_like(self.NN.fc2.bias)
        fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
        fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
        fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)

        # Third layer weight distribution priors
        fc_outw_mu = torch.randn_like(self.NN.fc_out.weight)
        fc_outw_sigma = torch.randn_like(self.NN.fc_out.weight)
        fc_outw_mu_param = pyro.param("fc_outw_mu", fc_outw_mu)
        fc_outw_sigma_param = softplus(pyro.param("fc_outw_sigma", fc_outw_sigma))
        fc_outw_prior = Normal(loc=fc_outw_mu_param, scale=fc_outw_sigma_param)

        # Third layer bias distribution priors
        fc_outb_mu = torch.randn_like(self.NN.fc_out.bias)
        fc_outb_sigma = torch.randn_like(self.NN.fc_out.bias)
        fc_outb_mu_param = pyro.param("fc_outb_mu", fc_outb_mu)
        fc_outb_sigma_param = softplus(pyro.param("fc_outb_sigma", fc_outb_sigma))
        fc_outb_prior = Normal(loc=fc_outb_mu_param, scale=fc_outb_sigma_param)
        
        priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 
                    'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior,
                    'fc_out.weight': fc_outw_prior, 'fc_out.bias': fc_outb_prior}

        lifted_module = pyro.random_module("module", self.NN, priors)
        
        return lifted_module()

    def initTransform(self):

        transform = A.Compose(
            [
                A.SmallestMaxSize(max_size=260),
                A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
                A.RandomCrop(height=224, width=224),
                A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        )

        return transform

    def predict(self, x):
        """
        for guide() after optimization iterations, the distribution given by the parameter values 
        approximate the true posterior and so we can use it for predictions. 

        for each prediction, we are sampling a new set of weights and parameters num_samples times. 
        This effectively means that we are sampling a new neural network 10 times for making one prediction.

        In our case, to make a prediction, we are averaging final layer output values of the 10 sampled nets for 
        the given input and taking the max activation value as the predicted digit.
        """

        sampled_models = [self.guide(None, None) for _ in range(self.num_samples)]
        yhats = [model(x).data for model in sampled_models]
        mean = torch.mean(torch.stack(yhats), 0)
        return np.argmax(mean.numpy(), axis=1)

    def trainingLoop(self):

        # Specify the function to optimize
        elbo = pyro.infer.Trace_ELBO()
        svi = pyro.infer.SVI(self.model, self.guide, self.optimizer, loss=elbo)

        # Prepare the dataset
        list_images = glob.glob(self.data_path + "/*.jpg")
        
        # Make a training and testing set
        train_images = random.sample(list_images, int(len(list_images) * 0.8))
        test_images = [item for item in list_images if item not in train_images]

        trainset = ImgLoader(train_images, transform=self.transform)
        testset = ImgLoader(test_images, transform=self.transform)

        trainLoader = DataLoader(trainset,
                                batch_size = self.batch_size, 
                                shuffle=True,
                                pin_memory=self.is_cuda)
        # Train the network
        for epoch in range(self.num_epochs):  # loop over the dataset multiple times

            training_loss = 0.0
            nb_train_batch = 0

            metrics_trn = torch.zeros(self.METRICS_SIZE, len(trainLoader.dataset), device=self.device)
            #metrics_val = torch.zeros(self.METRICS_SIZE, len(testLoader.dataset), device=self.device)

            for batch_ndx, tup_ndx in enumerate(trainLoader, 0):

                start_ndx = batch_ndx * self.batch_size
                end_ndx = start_ndx + self.batch_size

                # get the inputs; data is a list of [inputs, labels]
                imgs, labels = tup_ndx
                imgs = imgs.permute(0,3,1,2)

                # Imgs and labels to GPU if possible
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)

                # forward + backward + optimize
                loss = svi.step(imgs, labels)

                # Compute the metrics
                outputs = self.NN(imgs)
                _, predicted = torch.max(outputs, dim=1)

                metrics_trn[self.LABELS_NDX, start_ndx:end_ndx] = labels
                metrics_trn[self.PREDICTED_NDX, start_ndx:end_ndx] = predicted

                # Get accuracy
                correct = torch.eq(metrics_trn[self.LABELS_NDX], metrics_trn[self.PREDICTED_NDX]).sum()
                accuracy = correct / len(metrics_trn[self.LABELS_NDX])

                training_loss += loss 
                nb_train_batch += 1
     
        print("Epoch: {}, Training loss: {}, Training accuracy: {}".format(epoch, training_loss / nb_train_batch, accuracy))
        print('Finished Training')

        # Save the model
        torch.save(self.model.state_dict(), self.save_path)

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument("--data_path",
                        help="Path to the folder containing the data (train and test set will be sampled)",
                        required=True,
                        type=str
    )
    
    parser.add_argument("--model_path",
                        help="Path to the model weights",
                        required=True,
                        type=str
    )

    parser.add_argument("--save_path",
                        help="Number of epochs to train the model on",
                        required=True,
                        type=str
    )

    parser.add_argument("--num_epochs",
                        help="Number of epochs to train the model on",
                        required=True,
                        type=int
    )

    cli_args = parser.parse_args()

    trainingApp(cli_args.data_path, cli_args.model_path, cli_args.save_path, cli_args.num_epochs).trainingLoop()

I also computed the batch_shape at each fully connected layers and I get:

Guide shape
torch.Size([1000, 512]) #fc1 weights
torch.Size([1000]) # fc1 biases
torch.Size([512, 1000]) # fc2 weights
torch.Size([512]) # fc2 bias 
torch.Size([2, 512]) # fc_out weights
torch.Size([2, 512]) # fc_out bias
Model shape
torch.Size([1000, 512])
torch.Size([1000])
torch.Size([512, 1000])
torch.Size([512])
torch.Size([2, 512])
torch.Size([2])

Hi @bencr, you are seeing this error because you need the event_shapes of your weight priors to match the shapes of the weights, not the batch_shapes. For example, you would need to add .to_event(2) to your linear layer’s weight priors:

fc1w_prior = Normal(
    loc=torch.zeros_like(self.NN.fc1.weight),
    scale=torch.ones_like(self.NN.fc1.weight)
).to_event(2)

Note also that we generally recommend using TyXe for building Bayesian neural networks with Pyro, since there can be quite a lot of boilerplate involved in specifying BNNs and setting appropriate priors and initial guide parameter values by hand can be tricky.

If for some reason you do need to use raw Pyro, I suggest using the pyro.nn.PyroModule API, which simplifies working with torch.nn.Modules; a very simple example is shown in the Bayesian regression tutorial.

Thank you so much for your answer! Thanks to this tutorial on tensor shape with probability distribution I understand better now. Now the model runs but with another issue.

I think I started too ambitious for my first approach with Pyro and I simplified the script. Basically I am now loading a pre-trained ResNet50 from torch.hub and I convert its last FC layer as a probabilistic layer.

Nevertheless, during training the loss is very high (~5000) and the accuracy very low (50%) for a pre-trained model (I am using the kaggle cat / dog dataset which is completely balanced). What could be the cause of such bad training? Could the priors be mispecified?

Here is the code I am using:

# Load pre-trained model ResNet and add FC layers on top
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True # Fix bug for Pytorch 1.9
NN = torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True)

# Freeze the layers
for param in NN.parameters():
    param.requires_grad = False

NN.fc = nn.Linear(2048,2)
NN.to(device)

def model(x_data, y_data):

    log_softmax = nn.LogSoftmax(dim=1)

    fcw_prior = Normal(loc=torch.zeros_like(NN.fc.weight).to(device), scale=torch.ones_like(NN.fc.weight).to(device)).to_event(2)
    fcb_prior = Normal(loc=torch.zeros_like(NN.fc.bias).to(device), scale=torch.ones_like(NN.fc.bias).to(device)).to_event(1)
    
    priors = {'fc.weight': fcw_prior, 'fc.bias': fcb_prior}

    lifted_module = pyro.random_module("module", NN, priors)

    lifted_reg_model = lifted_module()
    lhat = log_softmax(lifted_reg_model(x_data)) 

    with pyro.plate("data", len(x_data)):
        pyro.sample("obs", Categorical(logits=lhat), obs=y_data)


def guide(x_data, y_data):
      
      softplus = torch.nn.Softplus()

      # First layer weight distribution priors
      fcw_mu = torch.randn_like(NN.fc.weight).to(device)
      fcw_sigma = torch.randn_like(NN.fc.weight).to(device)
      fcw_mu_param = pyro.param("fcw_mu", fcw_mu)
      fcw_sigma_param = softplus(pyro.param("fcw_sigma", fcw_sigma))
      fcw_prior = Normal(loc=fcw_mu_param, scale=fcw_sigma_param).to_event(2)

      # First layer bias distribution priors
      fcb_mu = torch.randn_like(NN.fc.bias).to(device)
      fcb_sigma = torch.randn_like(NN.fc.bias).to(device)
      fcb_mu_param = pyro.param("fcb_mu", fcb_mu)
      fcb_sigma_param = softplus(pyro.param("fcb_sigma", fcb_sigma))
      fcb_prior = Normal(loc=fcb_mu_param, scale=fcb_sigma_param).to_event(1)

      priors = {'fc.weight': fcw_prior, 'fc.bias': fcb_prior}

      lifted_module = pyro.random_module("module", NN, priors)
      
      return lifted_module()

And here is the very simple training loop:

optimizer = pyro.optim.Adam({"lr": 0.001})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, guide, optimizer, loss=elbo)

# Prepare the dataset
list_images = glob.glob(data_path + "/*.jpg")

trainset = ImgLoader(list_images, transform=transform)

trainLoader = DataLoader(trainset,
                        batch_size = 128, 
                        shuffle=True,
                        pin_memory=is_cuda)
# Train the network
for epoch in range(10):  # loop over the dataset multiple times

    training_loss = 0.0
    nb_train_batch = 0

    correct = 0
    total = 0

    for _, tup_ndx in enumerate(trainLoader, 0):

        # get the inputs; data is a list of [inputs, labels]
        imgs, labels = tup_ndx
        imgs = imgs.permute(0,3,1,2)

        # Imgs and labels to GPU if possible
        imgs = imgs.to(device)
        labels = labels.to(device)

        # forward + backward + optimize
        loss = svi.step(imgs, labels)

        # Compute the metrics
        outputs = NN(imgs)
        _, predicted = torch.max(outputs, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        training_loss += loss 
        nb_train_batch += 1

    print("Epoch: {}, Training loss: {}, Training accuracy: {}".format(epoch, training_loss / nb_train_batch, correct / total))
print('Finished Training')

to first approximation bayesian neural networks don’t work, especially if you’re not familiar with all the technical tricks and gotchas.

you’re probably getting killed by some combination of improper invocation of data subsampling in pyro and/or the lack of the local reparameterization trick.

it’s probably best to use TyXe, which can help deal with some of these technical issues.

I wanted to avoid using very high level packages but with all the moving parts you’re right it might be simpler to use TyXe for now.

I have been implementing the code shown in this paper, unfortunately I was unsuccessful. When using a homoskedastic Gaussian for the likelihood function I get this error:

ValueError: Error while computing log_prob at site 'likelihood.data':
Value is not broadcastable with batch_shape+event_shape: torch.Size([128]) vs torch.Size([128, 2]).

...

Sample Sites:                      
              net.fc.weight dist             |   2 512
                           value             |   2 512
                        log_prob             |        
                net.fc.bias dist             |   2    
                           value             |   2    
                        log_prob             |        
            likelihood.data dist             | 128   2
                           value             | 128    

My code is very minimalistic here:

import torch
import torch.nn as nn
import albumentations as A

import pyro
from pyro.distributions import Normal

from torch.utils.data import DataLoader

from TyXe.tyxe.priors import IIDPrior
from TyXe.tyxe.likelihoods import HomoskedasticGaussian
from TyXe.tyxe import VariationalBNN

from utils.loader import ImgLoader

import glob

transform = A.Compose(
            [
                A.SmallestMaxSize(max_size=260),
                A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
                A.RandomCrop(height=224, width=224),
                A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ]
        )

# Prepare the dataset
data_path = "/Data/train"
list_images = glob.glob(data_path + "/*.jpg")

trainset = ImgLoader(list_images, transform=transform)

trainLoader = DataLoader(trainset,
                        batch_size = 128, 
                        shuffle=True)

# Load pre-trained model ResNet and add FC layers on top
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True # Fix bug for Pytorch 1.9
NN = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
NN.to('cpu')

# Freeze the layers
for param in NN.parameters():
    param.requires_grad = False

NN.fc = nn.Linear(512,2)

ll_prior = IIDPrior(Normal(0, 1), expose_all=False, expose_modules=[NN.fc])
likelihood = HomoskedasticGaussian(len(list_images), event_dim=2, scale=1)
lr_guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal
bnn = VariationalBNN(NN, ll_prior, likelihood, lr_guide)
optim = pyro.optim.Adam({"lr": 1e-3})

# fit the model
bnn.fit(trainLoader, optim, 5)

sorry but i don’t use tyxe myself so you should probably ask here

1 Like

Hi, just to let you know that I solved the problem - which was pretty simple, I was using a Gaussian instead of a Categorical likelihood for my data. The model now can be trained and returns a coherent accuracy (~96% for only 3 epochs of training).

In any case thank you so much for your help!