Trying to create Bayesian convnets using Pyro


#1

Hi,
I am trying to create Bayesian CNNs using Pyro's model–guide framework. I trained the network for a few epochs; training takes a lot of time, and the loss only gets down to about 50 (categorical loss). The code is below. Could anyone tell me what I am doing wrong?

class NN(nn.Module):
    """Small convnet: two 3x3 convolutions, one 2x2 max-pool, two linear layers.

    The hard-coded ``64 * 12 * 12`` input size of ``fc1`` implies 1-channel
    28x28 inputs: each unpadded 3x3 conv shrinks the side by 2 (28 -> 26 -> 24)
    and the 2x2 max-pool halves it (24 -> 12). Other input sizes fail at ``fc1``.
    """

    def __init__(self):
        super(NN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.dropout = nn.Dropout(0.25)  # applied after the conv stack
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.dropout1 = nn.Dropout(0.5)  # applied after fc1
        self.out = nn.Linear(128, 10)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, 10); no softmax is applied."""
        output = F.relu(self.conv1(x))
        # Second conv, then a 2x2 max-pool that halves the spatial size.
        output = F.max_pool2d(F.relu(self.conv2(output)), (2, 2))
        output = self.dropout(output)

        # Flatten every dimension except the batch dimension.
        output = output.view(-1, self.num_flat_features(output))
        output = F.relu(self.fc1(output))
        # Bug fix: this used self.dropout (p=0.25) again, leaving the
        # p=0.5 dropout1 layer defined but never applied.
        output = self.dropout1(output)
        return self.out(output)

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of all dims except dim 0)."""
        num_features = 1
        for s in x.size()[1:]:  # dim 0 is the batch dimension
            num_features *= s
        return num_features

def model(x_data, y_data):
    """Pyro model: an independent N(0, 1) prior on every parameter of ``net``.

    Lifts the global ``net`` into a random module whose parameters are sampled
    from the priors, runs the sampled network on ``x_data`` and observes the
    labels ``y_data`` under a categorical likelihood at site ``"obs"``.
    """
    # One standard-normal prior per parameter tensor, keyed by the parameter's
    # name ("conv1.weight", "fc1.bias", ...), so the priors always cover exactly
    # the parameters the network has.
    priors = {
        name: Normal(loc=torch.zeros_like(param), scale=torch.ones_like(param))
        for name, param in net.named_parameters()
    }
    lifted_module = pyro.random_module("module", net, priors)
    lifted_reg_model = lifted_module()

    # Categorical(logits=...) normalizes the logits itself, so the previous
    # log_softmax (called without an explicit dim, which is deprecated) was
    # redundant.
    logits = lifted_reg_model(x_data)
    # NOTE(review): nothing here rescales the likelihood to the full dataset.
    # With minibatches, the prior/KL over all weights (fc1 alone has
    # 64*12*12*128 ~= 1.18M) is weighed against a single batch's
    # log-likelihood. Consider pyro.plate("data", N, subsample=...) or
    # poutine.scale; confirm the API for the installed Pyro version.
    pyro.sample("obs", Categorical(logits=logits), obs=y_data)

# Shared positivity transform for the guide's unconstrained *_sigma params.
# The explicit arguments are PyTorch's defaults: log(1 + exp(x)), and the
# identity once x exceeds 20 (for numerical stability).
softplus = torch.nn.Softplus(beta=1.0, threshold=20.0)

def guide(x_data, y_data):
    """Mean-field Normal variational posterior over every parameter of ``net``.

    For each parameter tensor of the global ``net`` this registers a learnable
    mean ``<layer><w|b>_mu`` and an unconstrained scale ``<layer><w|b>_sigma``
    (passed through ``softplus`` to keep it positive). The names are the same
    as before, e.g. ``"conv1w_mu"`` and ``"fc1b_sigma"``. It returns a network
    sampled from this posterior via ``pyro.random_module``.

    ``x_data`` and ``y_data`` are unused here. They are accepted because SVI
    calls the guide with the same arguments as the model.
    """
    # Initial unconstrained value for every *_sigma param:
    # softplus(-5.0) ~= 0.0067, so the posterior starts nearly deterministic.
    # Previously softplus(randn) gave scales around 0.7 on every weight.
    sigma_init = -5.0
    posteriors = {}
    for name, param in net.named_parameters():
        layer, kind = name.rsplit(".", 1)
        prefix = layer + ("w" if kind == "weight" else "b")
        # Means start from net's own PyTorch-default (fan-in scaled) initial
        # values instead of unit-variance randn_like. With fc1's 9216 inputs,
        # unit-scale weights blow up the activations.
        mu = pyro.param(prefix + "_mu", param.detach().clone())
        sigma = softplus(pyro.param(prefix + "_sigma", torch.full_like(param, sigma_init)))
        # Plain batch-shaped Normal, exactly like the model's priors. The old
        # .independent(1) on out.weight alone made the event shape at that
        # site disagree between model and guide.
        posteriors[name] = Normal(loc=mu, scale=sigma)
    lifted_module = pyro.random_module("module", net, posteriors)
    return lifted_module()

# Train with SVI for a fixed number of epochs (full passes over train_loader).
num_iterations = 5
# The dataset size does not change between epochs, so compute it once.
normalizer_train = len(train_loader.dataset)

for j in range(num_iterations):
    loss = 0
    for data in train_loader:
        # svi (presumably a pyro.infer.SVI, built elsewhere) takes one gradient
        # step on this batch and returns that batch's loss estimate.
        loss += svi.step(data[0], data[1])
    total_epoch_loss_train = loss / normalizer_train
    # Report every epoch. Previously this print sat outside the loop and
    # only ran once, after the final epoch.
    print("Epoch ", j, " Loss ", total_epoch_loss_train)

The prediction accuracy is very low with the convnet: the model's predictions on both the test and the training data are at chance level.
With a single fully connected Bayesian layer on MNIST, it reaches 88% accuracy.


#2

Bayesian neural networks are an active area of research. Generally speaking, variational approaches don’t work very well or reliably. This is especially true if the neural network has a large number of parameters (like a convnet), and things get even worse if you do things more or less naively (as would be the case if you use `pyro.random_module`). So I’d probably advise against trying to build Bayesian neural networks in Pyro (or in general) unless you’re an expert on the many difficult aspects of the problem.


#3

Could you please elaborate on the “many difficult aspects of the problem” that martinjankowiak mentioned (post 2)?

I am trying to do the same thing as Sareen1331, but I am using a ResNet with about 300,000 weights and biases. Is that considered large? I tried different learning-rate schedules, but so far the accuracy is only 70%, whereas the same deterministic network reaches about 99%. Thanks.


#4

(Very) roughly speaking, you probably shouldn’t expect a naive variational-inference approach to Bayesian neural networks to work unless you have many more data points than weights: $N \gg W$, where $N$ is the number of data points and $W$ the number of weights.


#5

Could you please explain a bit what a non-naive variational inference approach in Pyro would look like? Thanks.


#6

This is an active area of research that can’t be summarized very easily. See, for example, this paper from a few years ago, Kingma, Salimans & Welling, “Variational Dropout and the Local Reparameterization Trick” (NIPS 2015): http://papers.nips.cc/paper/5666-variational-dropout-and-the-local-reparameterization-trick