Poor performance of a Bayesian Convolutional Neural Network

Hello Pyro community,

I’m trying to build a Bayesian CNN for MNIST classification using Pyro, but despite seeing the ELBO loss decrease to around 10 during training, the model’s predictive accuracy remains at chance level (~10%). Could you help me understand why the loss improves while performance doesn’t, and suggest potential fixes?
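
For reference (and maybe this is where my reasoning goes wrong), I’m thinking of the ELBO as

    ELBO = E_{q(w)}[ log p(y | x, w) ] - KL( q(w) || p(w) ),

so in principle the loss could keep decreasing mostly through the KL term while the expected log-likelihood, and hence accuracy, barely moves.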

Code Overview:

import os

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoNormal
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Define model class
class BayesianCNN(PyroModule):
    def __init__(self):
        super().__init__()
        
        self.weight_scale = 1.0  # std of the Normal priors over the weights

        # First convolutional layer
        self.conv1 = PyroModule[nn.Conv2d](in_channels=1, out_channels=6, kernel_size=5)
    
        self.conv1.weight = PyroSample(
            dist.Normal(0., self.weight_scale).expand(self.conv1.weight.shape).to_event(self.conv1.weight.dim())
        )
        self.conv1.bias = PyroSample(
            dist.Normal(0., 1).expand(self.conv1.bias.shape).to_event(self.conv1.bias.dim())
        )
        
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        
        # Activation
        self.relu = nn.ReLU()
        
        # Second convolutional layer
        self.conv2 = PyroModule[nn.Conv2d](in_channels=6, out_channels=16, kernel_size=5)
        self.conv2.weight = PyroSample(
            dist.Normal(0., self.weight_scale).expand(self.conv2.weight.shape).to_event(self.conv2.weight.dim())
        )
        self.conv2.bias = PyroSample(
            dist.Normal(0., 1).expand(self.conv2.bias.shape).to_event(self.conv2.bias.dim())
        )
        
        # Fully connected layers (feature maps are 16 x 4 x 4 after the convs and pools)
        self.fc1 = PyroModule[nn.Linear](16 * 4 * 4, 120)
        self.fc1.weight = PyroSample(
            dist.Normal(0., self.weight_scale).expand(self.fc1.weight.shape).to_event(self.fc1.weight.dim())
        )
        self.fc1.bias = PyroSample(
            dist.Normal(0., 1).expand(self.fc1.bias.shape).to_event(self.fc1.bias.dim())
        )
        
        self.fc2 = PyroModule[nn.Linear](120, 84)
        self.fc2.weight = PyroSample(
            dist.Normal(0., self.weight_scale).expand(self.fc2.weight.shape).to_event(self.fc2.weight.dim())
        )
        self.fc2.bias = PyroSample(
            dist.Normal(0., 1).expand(self.fc2.bias.shape).to_event(self.fc2.bias.dim())
        )
        
        self.fc3 = PyroModule[nn.Linear](84, 10)
        self.fc3.weight = PyroSample(
            dist.Normal(0., self.weight_scale).expand(self.fc3.weight.shape).to_event(self.fc3.weight.dim())
        )
        self.fc3.bias = PyroSample(
            dist.Normal(0., 1).expand(self.fc3.bias.shape).to_event(self.fc3.bias.dim())
        )

    
    def forward(self, x, y=None):
        # Forward pass through the network
        x = self.pool(self.relu(self.conv1(x)))
        
        x = self.pool(self.relu(self.conv2(x)))
        
        x = x.view(-1, 16 * 4 * 4)  # Flatten
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        
        # Log-softmax for classification (Categorical normalizes logits
        # internally, so passing log-probabilities here is equivalent)
        logits = nn.functional.log_softmax(x, dim=1)
        
        with pyro.plate("data", x.shape[0]):
            # Labels are observed during training; y=None at prediction time
            pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        
        return logits

# Prepare the data and dataloaders
# Transformations: resize to 40 px, center-crop back to 28 px
# (this zooms in on the digit), then convert to a tensor
transformations = transforms.Compose([
    transforms.Resize(40),
    transforms.CenterCrop(28),
    transforms.ToTensor(),
])

# Get the datasets
mnist_path = os.path.join("..", "..", "data", "datasets")
train_data = MNIST(root=mnist_path, train=True, transform=transformations, download=True)
test_data = MNIST(root=mnist_path, train=False, transform=transformations, download=True)

# Create dataloaders
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

# Create the model
model = BayesianCNN()

# Use an autoguide
guide = AutoNormal(model)
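
# Optional sanity check (a sketch): trace one forward pass and print the
# model's sample sites and their shapes
sample_images, _ = next(iter(train_loader))
trace = pyro.poutine.trace(model).get_trace(sample_images)
print(trace.format_shapes())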

# Set up the optimizer
optimizer = pyro.optim.Adam({"lr": 0.01})

# Set up the SVI object
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# Training loop
losses = []
def train(train_loader, num_epochs=5):
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_id, (images, labels) in enumerate(train_loader):
            # Calculate the loss and take a gradient step
            loss = svi.step(images, labels)
            total_loss += loss
            
            if batch_id % 100 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_id}, Loss {loss/len(images)}")
        
        # Average loss for the epoch
        avg_loss = total_loss / len(train_loader.dataset)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1} completed. Average loss: {avg_loss}")
    return losses

losses = train(train_loader=train_loader, num_epochs=5)
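
# Plot the per-epoch average loss (a sketch; assumes matplotlib is available.
# The figure at the end of the post shows the per-step loss instead.)
import matplotlib.pyplot as plt
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Average ELBO loss per datapoint")
plt.show()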

# Quick performance test
model.eval()  # note: only affects nn layers like dropout/batchnorm, not Pyro's sampling
test_im, test_lab = next(iter(test_loader))
print(test_lab)

predictive = Predictive(model=model, guide=guide, num_samples=100)
sample_pred = predictive(test_im)

# Count how many of the 100 posterior predictive draws match the label
(sample_pred["obs"] == test_lab).sum()
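
For completeness, here is the kind of aggregate check I have in mind for the full test set (a sketch, using a majority vote over the posterior predictive samples; `test_accuracy` and `n_batches` are just illustrative names):

@torch.no_grad()
def test_accuracy(loader, n_batches=None):
    correct = total = 0
    for i, (images, labels) in enumerate(loader):
        if n_batches is not None and i >= n_batches:
            break
        preds = predictive(images)["obs"]    # shape: (num_samples, batch)
        majority = preds.mode(dim=0).values  # majority vote per image
        correct += (majority == labels).sum().item()
        total += labels.numel()
    return correct / total

print(test_accuracy(test_loader, n_batches=100))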

Observed Behavior:

  • ELBO loss decreases consistently over 5 epochs (final average ≈ 10)
  • Predictive checks with Predictive (100 posterior samples) show ~10% accuracy, i.e. chance level

Questions:

  1. Why would the ELBO improve while predictive performance stays at chance?
  2. Could this indicate that the guide is collapsing to the prior despite the decreasing loss?
  3. Are there specific diagnostic tools in Pyro to detect this? (A sketch of the kind of check I mean follows this list.)
  4. What are potential fixes?
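
For question 3, this is the kind of check I had in mind: inspect the learned variational parameters in the param store and compare the posterior scales to the prior scale of 1 (a sketch; "AutoNormal.locs.*" / "AutoNormal.scales.*" is the naming I see in the param store):

# Inspect the learned variational parameters of the AutoNormal guide
for name, value in pyro.get_param_store().items():
    v = value.detach()
    print(f"{name}: shape={tuple(v.shape)}, mean={v.mean():.3f}, std={v.std():.3f}")
# If the scales sit near the prior std (1.0) and the locs near 0,
# that would suggest the guide has not moved away from the prior.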

Any insights into this disconnect between ELBO and model performance would be greatly appreciated!

Plot of the ELBO loss against the SVI steps.