Bayesian MNIST Classification

Hi, I’m trying to implement a simple Bayesian model that performs classification on the MNIST dataset (something along the lines of this Edward-based tutorial) and that gives me a measure of the model’s certainty for an unseen image at test time.

For this purpose, I have this simple PyTorch model (it achieves ~93% accuracy when trained for 1 epoch with SGD and standard hyperparameters):

class Net(torch.nn.Module):
    def __init__(self, isCuda):
        super(Net, self).__init__()
        self.layer1 = torch.nn.Sequential(torch.nn.Conv1d(28*28, 64, kernel_size=1, bias=False ), torch.nn.BatchNorm1d(64))
        self.fc = torch.nn.Linear(64, 10, bias=None)
        self.relu = torch.nn.ReLU()
        self.isCuda = isCuda
    def forward(self, x):
        x = x.view(-1,28*28,1)
        x = self.relu(self.layer1(x))
        return  self.fc(torch.squeeze(x,2))
    def evaulateOnDataset(self, dataset):
        """ Here dataset is the validation/test set from a particular dataset (e.g. MNIST, CIFAR10) """
        with torch.no_grad(): # this is needed when evaluating the validation set
            correct = 0
            total = 0
            for data in dataset:
                images, labels = data
                if self.isCuda:
                    outputs = self.forward(torch.tensor(images).cuda())
                else:
                    outputs = self.forward(torch.tensor(images))

                # the index of the largest output is the predicted class
                _, predicted = torch.max(outputs, 1)

                total += labels.size(0)
                # predicted is of type torch.cuda.LongTensor
                predicted = predicted.cpu() # we need to convert it to a normal torch.LongTensor data type
                correct += (predicted == labels).sum()

        return total, correct

And this is my SVI setup (model + guide + optim), which I’ve put together following the tutorials in the Pyro documentation. I tried to make the necessary changes to adapt both the model and the guide to a classification task (unlike the VAE and regression tutorials).

This is the Model:

def model(input, y):
    # Create zero-mean normal priors over the parameters (scale 0.1 for the conv weights, 1 for the linear weights)
    cw_prior = dist.Normal(loc=torch.zeros_like(net.layer1[0].weight), scale=0.1*torch.ones_like(net.layer1[0].weight))
    fcw_prior = dist.Normal(loc=torch.zeros_like(net.fc.weight), scale=torch.ones_like(net.fc.weight))
    priors = {'layer1[0].weight': cw_prior, 'fc.weight' : fcw_prior}

    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)

    # sample a model
    lifted_model = lifted_module()

    with pyro.iarange("map", BATCH_SIZE):

        # run the classifier forward conditioned on data
        prediction = lifted_model(input)
        y_hot = to_one_hot(y, n_dims=10)
        pyro.sample("obs", dist.Normal(softmax(prediction), 0.1 * torch.ones(input.size(1)).type_as(input)), obs=y_hot)

This is the Guide:

def guide(input, y):
    # define our variational parameters (no need to specify "requires_grad=True", it will be done automatically
    # once they get passed to the "pyro.param(...)" statements below)
    cw_loc = torch.randn_like(net.layer1[0].weight)
    cw_log_sig = torch.tensor(0.25 * torch.randn_like(net.layer1[0].weight))
    fcw_loc = torch.randn_like(net.fc.weight)
    fcw_log_sig = torch.tensor(0.25 * torch.randn_like(net.fc.weight))

    # register learnable params in the param store
    mw_param = pyro.param("guide_mean_ConvWeight", cw_loc)
    sw_param = softplus(pyro.param("guide_log_scale_ConvWeight", cw_log_sig))
    mb_param = pyro.param("guide_mean_LinearWeight", fcw_loc)
    sb_param = softplus(pyro.param("guide_log_scale_LinearWeight", fcw_log_sig))

    # guide distributions for parameters
    w_dist = dist.Normal(mw_param, sw_param).independent(1)
    b_dist = dist.Normal(mb_param, sb_param).independent(1)
    dists = {'layer1[0].weight': w_dist, 'fc.weight': b_dist}

    # overload the parameters in the module with random samples
    # from the guide distributions
    lifted_module = pyro.random_module("module", net, dists)
    # sample model
    return lifted_module()

And this is the training loop:

net = Net(USE_CUDA).cuda() if USE_CUDA else Net(USE_CUDA)
optim = pyro.optim.SGD({"lr": 0.0001})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

for epoch in range(NUM_EPOCHS):

    print('Epoch: %d --> lr:%f' %(epoch, svi.optim.pt_optim_args['lr']))

    for i, data in enumerate(trainDataset, 0):
        # get the inputs
        inputs, labels_ = data

        if USE_CUDA:
            inputs, labels = torch.tensor(inputs.cuda()), torch.tensor(labels_.cuda())
        else:
            inputs, labels = torch.tensor(inputs), torch.tensor(labels_)

        labels = labels.type_as(inputs)

        # calculate the loss and take a gradient step
        loss = svi.step(inputs, labels)

        if i % 100 == 0:
            eval_net = guide(None, None) # IS THIS THE RIGHT WAY OF SAMPLING OUR MODEL?
            total, correct = eval_net.evaulateOnDataset(valDataset) # evaluation on validation set
            val_acc = float(correct) / total
            print('[%d, %5d] loss: %.4f --> Val acc: %.4f' % (epoch + 1, i + 1, np.exp(loss/ float(BATCH_SIZE)), val_acc))

# Now we sample a SINGLE (!!!) model and evaluate it on the testset
eval_net = guide(None, None) # IS THIS THE RIGHT WAY OF SAMPLING OUR MODEL?
total, correct = eval_net.evaulateOnDataset(testDataset)
test_acc = float(correct) / total
print('Accuracy of the network on the %d test images: %.4f %%' % (len(testDataset.dataset.test_labels),100 * test_acc))

And this is the kind of results I’m getting. The accuracy generally goes up for the first few batches but then it drops. I understand that each evaluation is a sample from the Bayesian model so some sort of variation between samples is understandable. I’ve tried both Adam and SGD with different learning rates and I’ve seen the same behavior over and over. Typical results:

Epoch: 0 --> lr:0.000100
[1,     1] loss: 5530719956970225788327559168.0000 --> Val acc: 0.2213
[1,   101] loss: 2.5607 --> Val acc: 0.7367
[1,   201] loss: 0.0491 --> Val acc: 0.6543
[1,   301] loss: 0.0021 --> Val acc: 0.7220
[1,   401] loss: 0.0242 --> Val acc: 0.5723
Epoch: 1 --> lr:0.000100
[2,     1] loss: 0.0026 --> Val acc: 0.6707
[2,   101] loss: 0.0617 --> Val acc: 0.6687
[2,   201] loss: 0.0064 --> Val acc: 0.7117
[2,   301] loss: 0.0628 --> Val acc: 0.6863
[2,   401] loss: 0.0544 --> Val acc: 0.4937
Epoch: 2 --> lr:0.000100
[3,     1] loss: 0.0003 --> Val acc: 0.7400
[3,   101] loss: 0.0074 --> Val acc: 0.6503
[3,   201] loss: 0.0076 --> Val acc: 0.5130
[3,   301] loss: 0.0005 --> Val acc: 0.5467
[3,   401] loss: 0.0068 --> Val acc: 0.4617
Accuracy of the network on the 10000 test images: 43.3333 %

My doubts are the following:

  • Is the way I’m comparing the predictions with the target labels correct? (last line in the model)
  • Is the way I’m sampling the model when evaluating it on the validation/test data correct? (i.e. passing “None” arguments to the guide)
  • I find my model quite unstable: some runs end up averaging about 25% accuracy, while others average about 60%. Are my variational parameters too restrictive (first lines in the guide)?
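For context on the second question: a certainty measure of the kind described at the top is often obtained by averaging the class probabilities of several sampled networks rather than evaluating a single sample. The following is only a minimal sketch of that idea in plain PyTorch. The names `w_loc`, `w_scale`, `images` and `predictive_probs` are hypothetical stand-ins (a single linear layer with a factorized Gaussian over its weights and a fake batch), not the network or guide from this post.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: mean and scale of a Gaussian over the weights
# of one linear layer (NOT the parameters learned by the guide above).
num_features, num_classes, num_samples = 784, 10, 20
w_loc = 0.01 * torch.randn(num_classes, num_features)
w_scale = 0.05 * torch.ones(num_classes, num_features)

images = torch.rand(5, num_features)  # fake batch of 5 flattened images


def predictive_probs(images, w_loc, w_scale, num_samples):
    """Average the softmax outputs over several sampled weight matrices."""
    probs = []
    for _ in range(num_samples):
        w = torch.distributions.Normal(w_loc, w_scale).sample()
        probs.append(torch.softmax(images @ w.t(), dim=1))
    return torch.stack(probs).mean(dim=0)


mean_probs = predictive_probs(images, w_loc, w_scale, num_samples)
predicted = mean_probs.argmax(dim=1)         # most probable class per image
confidence = mean_probs.max(dim=1).values    # averaged probability of that class
```

With a Pyro guide, each weight sample would presumably come from one call to the guide instead of the explicit `Normal(...).sample()` used here.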

how many parameters are you trying to be bayesian about? unless done with sufficient care, in many settings variational approaches to bayesian neural networks tend to catastrophically fail. one potential problem is the following. the ELBO breaks up into two terms (expected log likelihood minus the kl divergence):

$$\mathrm{ELBO} = \mathbb{E}_{q(w)}\left[\log p(\text{data} \mid w)\right] - \mathrm{KL}\left(q(w) \,\|\, p(w)\right)$$

for classification ELL will be of order N, where N is the number of datapoints and KL will be of order the number of weights D you’re being bayesian about. if N ~ D (or even worse D>>N) you will be in a regime where your neural network will be heavily overregularized. basically the KL divergence term prevents you from learning anything with decent accuracies. so if you want to be bayesian on MNIST as a loose rule of thumb you probably don’t want to be bayesian on more than, say, a few thousand parameters (unless you have some super flexible and magical guide at your disposal, but that’s another story).
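To make the scaling argument concrete, here is a rough back-of-the-envelope sketch. It assumes a fully factorized Gaussian posterior with the same mean and scale for every weight and a unit normal prior, and the numbers (`avg_log_lik`, `mu`, `sigma`) are illustrative assumptions, not measurements from the model in this thread.

```python
import math


def kl_normal_to_prior(mu, sigma, prior_sigma=1.0):
    """KL( N(mu, sigma^2) || N(0, prior_sigma^2) ) for a single weight."""
    return (math.log(prior_sigma / sigma)
            + (sigma ** 2 + mu ** 2) / (2.0 * prior_sigma ** 2)
            - 0.5)


def elbo_terms(n_data, n_weights, avg_log_lik, mu, sigma):
    """Return (ELL, KL), assuming every datapoint contributes avg_log_lik
    and every weight has the same factorized Gaussian posterior."""
    ell = n_data * avg_log_lik
    kl = n_weights * kl_normal_to_prior(mu, sigma)
    return ell, kl


# Illustrative numbers only (assumptions, not measurements):
N = 60_000            # size of the MNIST training set
avg_log_lik = -0.3    # per-image log likelihood of a reasonable classifier
mu, sigma = 0.0, 0.1  # a fairly narrow posterior for each weight

for D in (5_000, 50_000, 500_000):
    ell, kl = elbo_terms(N, D, avg_log_lik, mu, sigma)
    print(f"D={D:>7}: ELL={ell:.0f}  KL={kl:.0f}  KL/|ELL|={kl / abs(ell):.2f}")
```

Under these assumptions the KL term grows linearly with D while the ELL stays fixed by N, so the ratio between the two shifts toward the KL term as the number of bayesian weights grows.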


My network doesn’t have many parameters. The example shown in this post has just over 50K weights. It so happens that my network and training data have almost the same number of elements/points (N ~ D, as you mention in your reply). So maybe this is why it can’t do better than the results I’m getting?

I’ve modified this model to have far fewer parameters (~8K) and it seems to be even harder to train (even though the standard non-Bayesian model reaches the same 93% accuracy). Could you elaborate a bit on that “flexible and magical guide” you mentioned?

one issue (there may be more) seems to be your observation likelihood:

pyro.sample("obs", dist.Normal(softmax(prediction), 0.1 * torch.ones(input.size(1)).type_as(input)), obs=y_hot)

you probably want to use a categorical distribution.
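a minimal sketch of what a categorical likelihood computes, written in plain PyTorch so it runs on its own. the tensors `logits` and `labels` are made-up stand-ins for the network output and the integer class labels. in the Pyro model above the corresponding change would presumably be along the lines of `pyro.sample("obs", dist.Categorical(logits=prediction), obs=y)`, with `y` holding class indices rather than one-hot vectors (this is an untested assumption about how it would slot into that model).

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

logits = torch.randn(4, 10)          # stand-in for the network output `prediction`
labels = torch.tensor([3, 1, 7, 0])  # integer class indices, not one-hot vectors

# Categorical normalizes the logits internally, so no explicit softmax is needed.
likelihood = Categorical(logits=logits)
log_prob = likelihood.log_prob(labels)  # one log-probability per image

# The same quantity is the negative per-example cross-entropy.
reference = -torch.nn.functional.cross_entropy(logits, labels, reduction="none")
```

unlike the Normal likelihood on softmax outputs, this assigns a proper probability to each discrete label, and no one-hot encoding or hand-picked noise scale is required.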