Trying to create Bayesian convnets using Pyro

Hi,
I am trying to create Bayesian CNNs using Pyro's model-guide framework. I trained the network for a few epochs; it takes a lot of time, and the loss only comes down to around 50 (categorical loss). The code is below. Could anyone let me know what I am doing wrong?

import torch
import torch.nn as nn
import torch.nn.functional as F

import pyro
from pyro.distributions import Normal, Categorical


class NN(nn.Module):

    def __init__(self):
        super(NN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.dropout1 = nn.Dropout(0.5)
        self.out = nn.Linear(128, 10)

    def forward(self, x):
        output = F.relu(self.conv1(x))
        # Max pooling over a (2, 2) window
        output = F.max_pool2d(F.relu(self.conv2(output)), (2, 2))
        output = self.dropout(output)

        output = output.view(-1, self.num_flat_features(output))
        output = F.relu(self.fc1(output))
        output = self.dropout1(output)  # use the 0.5 dropout defined in __init__
        output = self.out(output)
        return output

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = NN()

def model(x_data, y_data):

    conv1w_prior = Normal(loc=torch.zeros_like(net.conv1.weight), scale=torch.ones_like(net.conv1.weight))
    conv1b_prior = Normal(loc=torch.zeros_like(net.conv1.bias), scale=torch.ones_like(net.conv1.bias))

    conv2w_prior = Normal(loc=torch.zeros_like(net.conv2.weight), scale=torch.ones_like(net.conv2.weight))
    conv2b_prior = Normal(loc=torch.zeros_like(net.conv2.bias), scale=torch.ones_like(net.conv2.bias))

    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))

    outw_prior = Normal(loc=torch.zeros_like(net.out.weight), scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias), scale=torch.ones_like(net.out.bias))

    priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
              'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'out.weight': outw_prior, 'out.bias': outb_prior}
    lifted_module = pyro.random_module("module", net, priors)
    lifted_reg_model = lifted_module()

    lhat = F.log_softmax(lifted_reg_model(x_data), dim=1)
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

softplus = torch.nn.Softplus()

def guide(x_data, y_data):

    #convolution layer weights
    conv1w_mu = torch.randn_like(net.conv1.weight)
    conv1w_sigma = torch.randn_like(net.conv1.weight)
    conv1w_mu_param = pyro.param("conv1w_mu", conv1w_mu)
    conv1w_sigma_param = softplus(pyro.param("conv1w_sigma", conv1w_sigma))
    conv1w_prior = Normal(loc=conv1w_mu_param, scale=conv1w_sigma_param)
    # First layer bias distribution priors
    conv1b_mu = torch.randn_like(net.conv1.bias)
    conv1b_sigma = torch.randn_like(net.conv1.bias)
    conv1b_mu_param = pyro.param("conv1b_mu", conv1b_mu)
    conv1b_sigma_param = softplus(pyro.param("conv1b_sigma", conv1b_sigma))
    conv1b_prior = Normal(loc=conv1b_mu_param, scale=conv1b_sigma_param)

    # Second convolution layer weights
    conv2w_mu = torch.randn_like(net.conv2.weight)
    conv2w_sigma = torch.randn_like(net.conv2.weight)
    conv2w_mu_param = pyro.param("conv2w_mu", conv2w_mu)
    conv2w_sigma_param = softplus(pyro.param("conv2w_sigma", conv2w_sigma))
    conv2w_prior = Normal(loc=conv2w_mu_param, scale=conv2w_sigma_param)
    # Second convolution layer biases
    conv2b_mu = torch.randn_like(net.conv2.bias)
    conv2b_sigma = torch.randn_like(net.conv2.bias)
    conv2b_mu_param = pyro.param("conv2b_mu", conv2b_mu)
    conv2b_sigma_param = softplus(pyro.param("conv2b_sigma", conv2b_sigma))
    conv2b_prior = Normal(loc=conv2b_mu_param, scale=conv2b_sigma_param)

    # Fully connected layer weights
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # Fully connected layer biases
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param) 
    # Output layer weight distribution priors
    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    outw_prior = Normal(loc=outw_mu_param, scale=outw_sigma_param).independent(1)
    # Output layer bias distribution priors
    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)
    priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
              'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'out.weight': outw_prior, 'out.bias': outb_prior}
    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
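
The construction of svi is not shown in the post; a standard setup would look like the following (the optimizer choice and learning rate are assumptions, not taken from the original run):

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# assumed setup for `svi` (not shown above): Adam optimizer + Trace_ELBO loss;
# the learning rate is a placeholder value
optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss=Trace_ELBO())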

num_iterations = 5

for j in range(num_iterations):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        # calculate the loss and take a gradient step
        loss += svi.step(data[0], data[1])
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    # report the average ELBO loss per training example for this epoch
    print("Epoch ", j, " Loss ", total_epoch_loss_train)

The prediction accuracy with the conv net is pretty low: predictions on both the training and test data are at chance level. With a single fully connected Bayesian layer on MNIST, the same setup gives 88% accuracy.
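
The evaluation code is not shown above, but a common way to estimate accuracy with this kind of model/guide pair is to draw several networks from the guide and average their class probabilities; a minimal sketch of that idea:

def predict(x, num_samples=10):
    # sample several concrete networks from the guide and average their
    # class probabilities; the argmax of the mean is the prediction
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    probs = torch.stack([F.softmax(m(x), dim=1) for m in sampled_models]).mean(0)
    return torch.argmax(probs, dim=1)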

bayesian neural networks are an active area of research. generally speaking, variational approaches don’t work very well/reliably. this is especially true if the neural network has a large number of parameters (like a convnet). and things get even worse if you do things more or less naively (as would be the case if you use pyro.random_module). so i’d probably advise against trying to build bayesian neural networks in pyro (or in general) unless you’re an expert on the many difficult aspects of the problem.

Could you please elaborate on the "many difficult aspects of the problem"?

I am trying to do the same thing as Sareen1331 did, but I am using a ResNet with about 300,000 weights/biases. Is that considered large? I have tried different learning rate schedules, but so far the best accuracy is about 70%, while the same deterministic neural network reaches about 99%. Thanks.

(very) roughly speaking, you probably shouldn’t expect a naive variational inference approach to bayesian neural networks to work unless you have more data points than weights: N>>W
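
as a quick sanity check you can compare the number of weights against the number of training points; a rough sketch, assuming the convnet from the original post and MNIST-sized data:

# rough parameter count for the convnet defined above (input assumed to be 28x28 MNIST)
num_weights = sum(p.numel() for p in NN().parameters())
num_data = 60_000  # MNIST training set size (assumption)
print(f"W = {num_weights:,} weights vs N = {num_data:,} data points")
# here W is roughly 1.2 million, so N >> W clearly does not hold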

Could you please explain a bit what a non-naive variational inference approach in Pyro would look like? Thanks.

this is an active area of research that can’t be summarized very easily. see, for example, this paper from a few years ago: Variational Dropout and the Local Reparameterization Trick
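
for a rough idea of what that paper does: instead of sampling the weights and then computing activations, you sample the (gaussian) pre-activations directly, which reduces gradient variance. a minimal sketch for a single linear layer with a factorized gaussian posterior (illustrative only, not pyro code):

import torch

def linear_local_reparam(x, w_mu, w_sigma, b_mu, b_sigma):
    # for a factorized Gaussian posterior over weights, the pre-activations
    # x @ W.T + b are themselves Gaussian, so sample them directly
    act_mu = x @ w_mu.t() + b_mu
    act_var = (x ** 2) @ (w_sigma ** 2).t() + b_sigma ** 2
    eps = torch.randn_like(act_mu)
    return act_mu + act_var.sqrt() * eps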