Apologies in advance if this type of question is frowned upon. I searched for BNN-related issues for a bit, but I did not find anything specific enough.
I’m currently following this guide for making a BNN.
Running through the code, the network learns and is accurate on the MNIST data set used in the notebook. Aside from a few annoying things with the code (e.g. variables that are declared outside of functions and then used inside them…), I thought it made a good deal of sense.
So, I tried to modify it to work with my dataset. I am quite literally just substituting out the dataset and changing a few of the network sizes to fit my data, and from what I can tell that’s it. And yet, the network doesn’t appear to have good accuracy. What’s most mysterious is that the loss improves, but when it comes time to evaluate accuracy via predictions… nope.
Note: some small readability issues stem from the variables mentioned earlier, which are declared outside of functions and then used inside them. For example, the ‘net’ used in the guide function is actually a global variable, so when the predict function calls the guide function, the net declared at the start is being used.
The code further below gives this output:
Started
Epoch 0 Loss 118.34769998327617
Epoch 1 Loss 28.149548526412868
Epoch 2 Loss 17.477245631752957
Epoch 3 Loss 17.329567421994327
Epoch 4 Loss 15.06282152032439
Epoch 5 Loss 14.35670173609586
Epoch 6 Loss 15.07527992447901
Epoch 7 Loss 15.750993160267619
Epoch 8 Loss 14.207093004361747
Epoch 9 Loss 14.708512778656106
Validation Accuracy
accuracy: 21 %
Training accuracy
accuracy: 20 %
Could it really just be a hyperparameter issue? I find that hard to believe, because I even set the network to roughly 1000 hidden nodes and trained for 500 epochs, just to see whether it could memorize my training set, and it cannot. The number of input features is not enormous, and there is a very strong signal in this data: a vanilla network achieves ~85% accuracy, which is very good for this type of data. There’s also a good chance that I just fundamentally don’t understand Pyro yet, or at least how to use it.
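As a point of reference, here is a small sketch of what chance level looks like for a 5-class problem (the 5 comes from the output size in NN(22, 10, 5)). It assumes the classes are roughly balanced, which is an assumption and not something shown in this post; chance_baseline is a hypothetical helper name.

```python
import math

def chance_baseline(num_classes):
    """Accuracy and per-example negative log-likelihood of a uniform guess.

    Assumes balanced classes (an assumption, not checked against the real data).
    """
    accuracy = 1.0 / num_classes
    nll = math.log(num_classes)
    return accuracy, nll

acc, nll = chance_baseline(5)
print("chance accuracy: %.1f %%" % (100 * acc))
print("chance NLL per example: %.3f" % nll)
```

Under that assumption, the reported 20-21 % accuracy sits at chance level. The printed loss is not directly comparable to the chance NLL, because the Trace_ELBO loss is the negative ELBO and so also contains the KL-type terms between the guide and the prior, not only the data log-likelihood.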
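import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

class NN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        output = self.fc1(x)
        output = F.relu(output)
        output = self.out(output)
        return output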
# For debugging purposes, this switch toggles between the MNIST data set and mine
MNIST = False
if not MNIST:
    net = NN(22, 10, 5)
else:
    net = NN(28*28, 1024, 10)
log_softmax = nn.LogSoftmax(dim=1)

def model(x_data, y_data):
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight), scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias), scale=torch.ones_like(net.out.bias))
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'out.weight': outw_prior, 'out.bias': outb_prior}
    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # sample a regressor (which also samples w and b)
    lifted_reg_model = lifted_module()
    lhat = log_softmax(lifted_reg_model(x_data))
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)
softplus = torch.nn.Softplus()

def guide(x_data, y_data):
    # First layer weight distribution priors
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # First layer bias distribution priors
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
    # Output layer weight distribution priors
    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    outw_prior = Normal(loc=outw_mu_param, scale=outw_sigma_param).independent(1)
    # Output layer bias distribution priors
    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'out.weight': outw_prior, 'out.bias': outb_prior}
    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

if not MNIST:
    num_iterations = 10
loss = 0
print('Started')
for j in range(num_iterations):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        # calculate the loss and take a gradient step
        X = data[0]
        Y = data[1]
        loss += svi.step(X, Y)
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    print("Epoch ", j, " Loss ", total_epoch_loss_train)
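num_samples = 10

def predict(x):
    # each call to guide() samples one network (it uses the global net)
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    # loop variable renamed so it does not shadow the model() function
    yhats = [sampled_model(x).data for sampled_model in sampled_models]
    mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.numpy(), axis=1)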
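To spell out what predict does: it averages the raw network outputs over num_samples sampled networks and then takes the argmax per row. Here is a toy sketch of the same computation in plain Python, with made-up numbers; average_and_argmax is a hypothetical helper, not part of the guide or of Pyro.

```python
def average_and_argmax(sampled_outputs):
    """sampled_outputs[s][i][c] is the output of sample s for example i, class c.

    Returns, for each example, the class with the largest mean output.
    """
    num_samples = len(sampled_outputs)
    num_examples = len(sampled_outputs[0])
    num_classes = len(sampled_outputs[0][0])
    predictions = []
    for i in range(num_examples):
        # average each class score over the sampled networks
        mean_row = [
            sum(sampled_outputs[s][i][c] for s in range(num_samples)) / num_samples
            for c in range(num_classes)
        ]
        predictions.append(max(range(num_classes), key=lambda c: mean_row[c]))
    return predictions

# Two sampled networks, two examples, three classes (made-up values).
toy = [
    [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[0.0, 1.0, 0.0], [0.0, 3.0, 0.0]],
]
print(average_and_argmax(toy))
```

If the sampled networks disagree strongly with each other, their averaged outputs can end up nearly flat, so printing the spread across samples for a few inputs might be a useful thing to look at.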
print('Validation Accuracy')
correct = 0
total = 0
for j, data in enumerate(validation_loader):
    X, labels = data
    predicted = predict(X)
    for i in range(len(predicted)):
        if predicted[i] == labels[i]:
            correct += 1
    total += labels.size(0)
print("accuracy: %d %%" % (100 * correct / total))

print('Training accuracy')
correct = 0
total = 0
for j, data in enumerate(train_loader):
    X, labels = data
    predicted = predict(X)
    for i in range(len(predicted)):
        if predicted[i] == labels[i]:
            correct += 1
    total += labels.size(0)
print("accuracy: %d %%" % (100 * correct / total))
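One check that might narrow this down is to compare the histogram of predicted classes with the histogram of true labels. If nearly all predictions fall into a single class, the ~20 % would point to a collapsed predictor rather than noisy learning. A minimal sketch with made-up labels follows; label_histograms is a hypothetical helper name.

```python
from collections import Counter

def label_histograms(predicted, labels):
    """Return (prediction counts, label counts, accuracy) for two equal-length lists."""
    if len(predicted) != len(labels):
        raise ValueError("predicted and labels must have the same length")
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return Counter(predicted), Counter(labels), correct / len(labels)

# Made-up example: balanced labels and a predictor that always answers class 3.
labels = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
predicted = [3] * len(labels)
pred_counts, label_counts, acc = label_histograms(predicted, labels)
print("predicted:", dict(pred_counts))
print("labels:   ", dict(label_counts))
print("accuracy: %.1f %%" % (100 * acc))
```

With the loaders above, the inputs could be collected per batch as predict(X).tolist() and labels.tolist() before calling the helper.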