# Help Converting a Complex PyTorch Classification Network to Pyro

I have gone through all of the tutorials and the other topics on this forum, but I can't figure out how to extend them to a more complex Bayesian neural network for classification. Below is what my research has led me to so far. Please help me correct my misunderstandings.

``````python
import numpy as np
import torch
import torch.nn as nn

import pyro
from pyro.distributions import Categorical, Normal
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn.module import PyroModule, PyroSample, to_pyro_module_
from pyro.optim import SGD, ReduceLROnPlateau
``````

Here is my model in plain PyTorch. The tutorials I read used only `Linear` layers, and only two or so of them.

`````` class MyTorchNet(nn.Module):
def __init__(self, num_classes, hs, input_num):
super(MyTorchNet, self).__init__()
layers = []  #Build Sequential
layers.append(nn.BatchNorm2d(hs))
layers.append(nn.MaxPool2d(2, stride=2))
layers.append(nn.ReLU(True))
layers.append(nn.BatchNorm2d(2*hs))
layers.append(nn.MaxPool2d(2, stride=2))
layers.append(nn.ReLU(True))
layers.append(FlattenLayer(num_classes))
layers.append(nn.Softmax(1))
self.cnn = nn.Sequential(*layers)
def forward(self,x):
return self.cnn(x)
``````

This is my attempt at making the network above Bayesian.

`````` class BNN(PyroModule):
def __init__(self, num_classes, hs, input_num):
super().__init__()
self.net = MyTorchNet( num_classes, hs, input_num)  # Initialize torch net
to_pyro_module_(self.net)  # Turn it into a Pyro Module
for m in self.net.cnn: # Indicate that Conv layers should be pulled from a Normal Dist
if isinstance(m,nn.Conv2d):
m.weight = PyroSample(prior=Normal(m.weight,m.weight))
m.bias = PyroSample(prior=Normal(m.bias,m.bias))
self.model = self.bayes_model
self.guide = AutoDiagonalNormal(self.model)
def forward(self,x,y):  # Not sure if this is right
return self.model(x,y)
def bayes_model(self,x_data, y_data):   # This might be duplicating my work in the init function?
# define priors
get_params = lambda w: (torch.zeros_like(w), torch.ones_like(w))
priors = {name:Normal(*get_params(w)) for name, w in
self.net.named_parameters()}
# lift onto a random pyro module
lifted_module = pyro.random_module("module", self.net, priors)
lifted_reg_model = lifted_module()

# define rest of model with likelihood
yhat = lifted_reg_model(x_data)
pyro.sample("obs", Categorical(yhat), obs=y_data)
return yhat
def predict(self,x): # Function to help get predictions for testing model
predictive = Predictive(self.model, self.guide(), num_samples=10, return_sites=("obs","_RETURN"))
pred_dict = predictive(x,None)
``````

This is how I'm training it with a DataLoader:

``````model = BNN(10,5,48).cuda()
scheduler = ReduceLROnPlateau({'optimizer': SGD,'optim_args':{'lr':0.00002,'momentum':0.9}})
svi = SVI(model.model, model.guide, scheduler, loss=Trace_ELBO())
pyro.clear_param_store()
for j in range(100):
total_loss = 0