Pyro Conversion of Complex PyTorch Classification Neural Network Help

Hello,

I have gone through all of the tutorials and the other topics on this forum, but can’t seem to extrapolate how to build a complex Bayesian NN for classification. Below is what all of my research has afforded me. Please help me correct my misunderstandings.

 import torch
 import torch.nn as nn
 import numpy as np
 import pyro
 from pyro.distributions import Normal, Categorica
 from pyro.infer import SVI, Trace_ELBO, Predictive
 from pyro.optim import SGD,ReduceLROnPlateau
 from pyro.nn.module import PyroModule, to_pyro_module_, PyroSample
 from pyro.infer.autoguide import AutoDiagonalNormal

Here I design my complex model in PyTorch. The tutorials I read only had Linear layers, and only had 2 or so layers.

 class MyTorchNet(nn.Module):
     def __init__(self, num_classes, hs, input_num):
         super(MyTorchNet, self).__init__()
         layers = []  #Build Sequential
         layers.append(nn.Conv2d(input_num, hs, kernel_size=3, stride=1, padding=1))
         layers.append(nn.BatchNorm2d(hs))
         layers.append(nn.MaxPool2d(2, stride=2))
         layers.append(nn.ReLU(True))
         layers.append(nn.Conv2d(hs, 2*hs, kernel_size=3, stride=1, padding=1))
         layers.append(nn.BatchNorm2d(2*hs))
         layers.append(nn.MaxPool2d(2, stride=2))
         layers.append(nn.ReLU(True))
         layers.append(nn.Conv2d(2*hs, num_classes, kernel_size=1, stride=1, padding=0))
         layers.append(nn.AdaptiveAvgPool2d(1))
         layers.append(FlattenLayer(num_classes))
         layers.append(nn.Softmax(1))
         self.cnn = nn.Sequential(*layers)
     def forward(self,x):
         return self.cnn(x)

This is my attempt at Baysian-fying the above network.

 class BNN(PyroModule):
     def __init__(self, num_classes, hs, input_num):
         super().__init__()
         self.net = MyTorchNet( num_classes, hs, input_num)  # Initialize torch net
         to_pyro_module_(self.net)  # Turn it into a Pyro Module
         for m in self.net.cnn: # Indicate that Conv layers should be pulled from a Normal Dist
             if isinstance(m,nn.Conv2d):
                 m.weight = PyroSample(prior=Normal(m.weight,m.weight))
                 m.bias = PyroSample(prior=Normal(m.bias,m.bias))
         self.model = self.bayes_model 
         self.guide = AutoDiagonalNormal(self.model)
     def forward(self,x,y):  # Not sure if this is right
         return self.model(x,y)
     def bayes_model(self,x_data, y_data):   # This might be duplicating my work in the init function?
         # define priors
         get_params = lambda w: (torch.zeros_like(w), torch.ones_like(w))
         priors = {name:Normal(*get_params(w)) for name, w in 
                   self.net.named_parameters()}
         # lift onto a random pyro module
         lifted_module = pyro.random_module("module", self.net, priors)
         lifted_reg_model = lifted_module()

         # define rest of model with likelihood
         yhat = lifted_reg_model(x_data)
         pyro.sample("obs", Categorical(yhat), obs=y_data)
         return yhat
     def predict(self,x): # Function to help get predictions for testing model
         predictive = Predictive(self.model, self.guide(), num_samples=10, return_sites=("obs","_RETURN"))
         pred_dict = predictive(x,None)
         return torch.mode(pred_dict['obs'],0)[0]

This is how I’m training it using a dataloader

model = BNN(10,5,48).cuda()
scheduler = ReduceLROnPlateau({'optimizer': SGD,'optim_args':{'lr':0.00002,'momentum':0.9}})
svi = SVI(model.model, model.guide, scheduler, loss=Trace_ELBO())
pyro.clear_param_store()
for j in range(100):
    total_loss = 0
    for x,y in loader:
        x,y = x.cuda(),y.cuda()
        loss = svi.step(x,y)
        total_loss += loss 
    scheduler.step(total_loss / len(loader.dataset))

Is this the right way to do it? What things in it are wrong/not needed?
The inputs are batched 1x48x48 images and I expect the output to be a number between 1 and 10 (10 classes). Right now it runs, but the loss is NaN, and it breaks on my prediction function.

I’m using Pyro v 1.3.1 and torch v 1.5.0