BNN - Loss vs Accuracy?

Apologies in advance if this type of question is frowned upon. I searched the forum for BNN-related issues for a while, but did not find anything specific enough.

I’m currently following this guide for making a BNN.

Running through the code, the network learns and is accurate on the MNIST dataset used in the notebook. Aside from a few annoying things in the code (e.g. variables declared outside of functions but used inside them), I thought it made a good deal of sense.

So I tried to modify it to work with my dataset. I am quite literally just swapping out the dataset and changing a few of the network sizes to fit my data, and as far as I can tell that's it. And yet the network doesn't reach good accuracy. What's most mysterious is that the loss improves, but when it comes time to evaluate accuracy via predictions... nope.

Note: some small readability issues stem from what I described earlier, i.e. variables declared outside of functions and then used inside them. For example, ‘net’ in the guide function is actually a global, so when the predict function calls the guide, it uses the ‘net’ declared at the start.

The following code gives the output:
Epoch 0 Loss 118.34769998327617
Epoch 1 Loss 28.149548526412868
Epoch 2 Loss 17.477245631752957
Epoch 3 Loss 17.329567421994327
Epoch 4 Loss 15.06282152032439
Epoch 5 Loss 14.35670173609586
Epoch 6 Loss 15.07527992447901
Epoch 7 Loss 15.750993160267619
Epoch 8 Loss 14.207093004361747
Epoch 9 Loss 14.708512778656106
Validation Accuracy
accuracy: 21 %
Training accuracy
accuracy: 20 %

Could it really just be a hyperparameter issue? I find that hard to believe, because I even set my network to about 1000 hidden nodes and 500 epochs, just to see whether it could memorize my training set, and it cannot. The number of input features is not large, and there is a very strong signal in this data: a vanilla network achieves ~85% accuracy, which is very good for this type of data. There's also a good chance I just fundamentally don't understand Pyro yet, or at least how to use it.
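With 5 output classes (the NN(22, 10, 5) configuration), 20% accuracy is exactly what uniform random guessing achieves, so the numbers above suggest the predictions carry no information at all, not merely that the model is under-tuned. A quick sanity check is to compare against trivial baselines. This is a minimal NumPy sketch; `baseline_accuracies` is a hypothetical helper and the labels are made up, not the poster's data.

```python
import numpy as np

def baseline_accuracies(labels, num_classes):
    """Accuracy of two trivial predictors, for comparison with a trained model."""
    labels = np.asarray(labels)
    # uniform random guessing: expected accuracy is 1 / num_classes
    chance = 1.0 / num_classes
    # always predicting the most frequent class in `labels`
    counts = np.bincount(labels, minlength=num_classes)
    majority = counts.max() / labels.size
    return chance, majority

# Hypothetical, roughly balanced 5-class labels (not the poster's data)
rng = np.random.default_rng(0)
y = rng.integers(0, 5, size=1000)
chance, majority = baseline_accuracies(y, num_classes=5)
print(f"chance: {chance:.0%}, majority class: {majority:.0%}")
```

If the BNN's accuracy sits at or below both baselines, the problem is more likely in the inference setup (guide, initialization, prediction code) than in fine-grained hyperparameters.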

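import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

class NN(nn.Module):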
    def __init__(self, input_size, hidden_size, output_size):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        output = self.fc1(x)
        output = F.relu(output)
        output = self.out(output)
        return output
# For debugging: a switch that toggles between the MNIST dataset and mine
MNIST = False
if not MNIST:
    net = NN(22, 10, 5)          # my data: 22 input features, 5 classes
else:
    net = NN(28*28, 1024, 10)    # MNIST: 784 pixels, 10 classes

log_softmax = nn.LogSoftmax(dim=1)

def model(x_data, y_data):

    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))

    outw_prior = Normal(loc=torch.zeros_like(net.out.weight), scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias), scale=torch.ones_like(net.out.bias))

    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,  'out.weight': outw_prior, 'out.bias': outb_prior}

    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # sample a regressor (which also samples w and b)
    lifted_reg_model = lifted_module()
    lhat = log_softmax(lifted_reg_model(x_data))
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

softplus = torch.nn.Softplus()
def guide(x_data, y_data):
    # First layer weights: variational distribution q (named *_prior, but it is the posterior approximation)
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # First layer bias: variational distribution q
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
    # Output layer weights: variational distribution q
    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    outw_prior = Normal(loc=outw_mu_param, scale=outw_sigma_param).independent(1)
    # Output layer bias: variational distribution q
    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'out.weight': outw_prior, 'out.bias': outb_prior}

    lifted_module = pyro.random_module("module", net, priors)

    return lifted_module()

optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

if not MNIST:
    num_iterations = 10

    for j in range(num_iterations):
        loss = 0
        for batch_id, data in enumerate(train_loader):
            # calculate the loss and take a gradient step
            X = data[0]
            Y = data[1]
            loss += svi.step(X, Y)
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = loss / normalizer_train

        print("Epoch ", j, " Loss ", total_epoch_loss_train)

    num_samples = 10
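    def predict(x):
        # draw num_samples networks from the variational posterior (the guide)
        sampled_models = [guide(None, None) for _ in range(num_samples)]
        # `m` rather than `model`, to avoid confusion with the Pyro model function
        yhats = [m(x).data for m in sampled_models]
        # average the raw outputs (logits) over the sampled networks
        mean = torch.mean(torch.stack(yhats), 0)
        return np.argmax(mean.numpy(), axis=1)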

    print('Validation Accuracy')
    correct = 0
    total = 0
    for j, data in enumerate(validation_loader):
        X, labels = data
        predicted = predict(X)
        for i in range(len(predicted)):
            if predicted[i] == labels[i]:
                correct += 1
        total += labels.size(0)
    print("accuracy: %d %%" % (100 * correct / total))

    print('Training accuracy')
    correct = 0
    total = 0
    for j, data in enumerate(train_loader):
        X, labels = data
        predicted = predict(X)
        for i in range(len(predicted)):
            if predicted[i] == labels[i]:
                correct += 1
        total += labels.size(0)
    print("accuracy: %d %%" % (100 * correct / total))
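The `predict` function above averages the raw outputs (logits) of the sampled networks. A common alternative is to average the predicted class probabilities, which is the usual Bayesian model average of the predictive distribution; the two can pick different classes when the sampled networks disagree. This is a NumPy sketch under that assumption; `predict_from_samples` and `accuracy` are hypothetical helpers, and the logits are made up.

```python
import numpy as np

def softmax(logits, axis=-1):
    # subtract the row max before exponentiating, for numerical stability
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def predict_from_samples(logit_samples):
    """logit_samples has shape (num_samples, batch, num_classes):
    one slice of raw outputs per network drawn from the guide."""
    # Bayesian model average: mean of the class probabilities over samples
    mean_probs = softmax(logit_samples, axis=-1).mean(axis=0)
    return mean_probs.argmax(axis=1), mean_probs

def accuracy(predicted, labels):
    # fraction of positions where prediction and label agree
    return float(np.mean(np.asarray(predicted) == np.asarray(labels)))

# Hypothetical logits from 3 sampled networks, for 2 inputs and 3 classes
samples = np.array([
    [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[1.5, 0.5, 0.0], [0.0, 2.0, 0.0]],
    [[2.5, 0.0, 0.5], [0.0, 1.2, 0.0]],
])
pred, probs = predict_from_samples(samples)
print(pred, accuracy(pred, [0, 1]))
```

In the original code, `torch.stack(yhats).numpy()` already has this (num_samples, batch, num_classes) shape, and a vectorized comparison like the one in `accuracy` can replace the inner per-element counting loop.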

this seems like a data-specific issue. these things are generally sensitive to prior initialization, and there’s no reason to expect the same network to work on completely different datasets. i’d recommend playing around with the initializations and also other tricks to control variance, e.g. local reparameterization.
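The local reparameterization trick mentioned above samples each layer's pre-activations instead of its weights: for a linear layer with independent Gaussian weights, x @ W + b is itself Gaussian, so one can draw it directly with fresh noise per example, which typically lowers gradient variance within a minibatch. This is a NumPy sketch of the idea only, not Pyro code; `lrt_linear` is a hypothetical name, and the weights are stored as (in_features, out_features), i.e. transposed relative to `nn.Linear.weight`.

```python
import numpy as np

def lrt_linear(x, w_mu, w_sigma, b_mu, b_sigma, rng):
    """Mean-field Gaussian linear layer via the local reparameterization trick.

    Weights W ~ N(w_mu, w_sigma^2) and bias b ~ N(b_mu, b_sigma^2), all independent,
    with W stored as (in_features, out_features). Rather than sampling W,
    sample the Gaussian pre-activation x @ W + b directly.
    """
    mean = x @ w_mu + b_mu
    var = (x ** 2) @ (w_sigma ** 2) + b_sigma ** 2
    eps = rng.standard_normal(mean.shape)  # independent noise per example and unit
    return mean + np.sqrt(var) * eps

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))            # batch of 4 inputs, 3 features
w_mu = rng.standard_normal((3, 2))
w_sigma = np.full((3, 2), 0.1)
b_mu, b_sigma = np.zeros(2), np.full(2, 0.1)
out = lrt_linear(x, w_mu, w_sigma, b_mu, b_sigma, rng)
print(out.shape)  # (4, 2)
```

With all scales set to zero this reduces to the deterministic layer x @ w_mu + b_mu.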

Hmm, okay. Could I ask why these models are so sensitive to prior initialization? I'm not very experienced with this type of model.
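One plausible concrete reason, for the guide in the question: the variational means are initialized with `torch.randn_like` (standard deviation 1) and the scales with `softplus(randn)`, so each weight sampled at the start of training has variance greater than 1. For a 784-input first layer that makes the pre-activations enormous, whereas scaling the means by 1/sqrt(fan_in) keeps them near unit size. This NumPy sketch measures the effect; `preactivation_std` is a hypothetical helper, the inputs are assumed standardized, and `rho = -5` with the 1/sqrt(fan_in) mean scale are illustrative choices, not values from the guide or from Pyro.

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def preactivation_std(fan_in, mu_scale, rho, n_units=256, batch=512, seed=0):
    """Std of x @ W for one draw W ~ Normal(mu, softplus(rho)^2), where
    x ~ N(0, 1) (standardized inputs) and mu = mu_scale * N(0, 1).
    rho is a fixed float, or None for rho ~ N(0, 1) as in the guide."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, fan_in))
    mu = mu_scale * rng.standard_normal((fan_in, n_units))
    if rho is None:
        rho_values = rng.standard_normal((fan_in, n_units))
    else:
        rho_values = np.full((fan_in, n_units), rho)
    w = mu + softplus(rho_values) * rng.standard_normal((fan_in, n_units))
    return float((x @ w).std())

fan_in = 28 * 28  # first layer in the MNIST configuration
guide_style = preactivation_std(fan_in, mu_scale=1.0, rho=None)
scaled = preactivation_std(fan_in, mu_scale=1.0 / np.sqrt(fan_in), rho=-5.0)
print(f"guide-style init: {guide_style:.1f}, scaled init: {scaled:.2f}")
```

Very large pre-activations saturate the softmax and make the Monte Carlo gradient estimates noisy, which is one way a poor initialization can stall learning even while the ELBO decreases.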

Do these types of networks need input features to be a certain format or scale?
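A common, not Pyro-specific, practice is to standardize each input feature to zero mean and unit variance, using statistics computed on the training split only. With the Normal(0, 1) weight priors in the model above, inputs on roughly unit scale keep the prior over pre-activations in a reasonable range. This is a minimal NumPy sketch; the helper names and the feature values are hypothetical.

```python
import numpy as np

def fit_standardizer(x_train, eps=1e-8):
    """Per-feature mean and std, computed on the training split only."""
    mean = x_train.mean(axis=0)
    std = x_train.std(axis=0)
    return mean, np.maximum(std, eps)  # guard against constant features

def standardize(x, mean, std):
    return (x - mean) / std

rng = np.random.default_rng(0)
# hypothetical raw features on very different scales
x_train = np.column_stack([rng.normal(1000, 200, 500), rng.normal(0.01, 0.002, 500)])
x_val = np.column_stack([rng.normal(1000, 200, 100), rng.normal(0.01, 0.002, 100)])

mean, std = fit_standardizer(x_train)
x_train_s = standardize(x_train, mean, std)
x_val_s = standardize(x_val, mean, std)  # reuse the training statistics
```

The standardized arrays can then be converted with `torch.as_tensor(..., dtype=torch.float32)` before building the data loaders.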

afaict that guide is incorrect so i wouldn’t suggest that you base your code on it.

i haven’t vetted it but this is probably a better place to look.

in general, bayesian neural networks are an active area of research and may not be a great entry point for non-experts. (also, they generally give poor empirical performance, at least unless you apply all kinds of expert tweaks, and even then the performance is generally poor)

I ran into the same problem, and my result was actually worse: accuracy on MNIST is 10%, i.e. chance level for 10 classes, which means the network didn't learn anything.
But I simply copied the code from the GitHub website and, to test it, left all the parameters unchanged. Even though the weights are probabilistic and express uncertainty, the result shouldn't be completely different, right?
Could you kindly take a look at that code? Or could you please provide a tutorial or example, since Bayesian neural networks are an important topic in PPL? (I think you are one of the developers on the Pyro team, right?)