NUTS for neural networks

Hello everyone!
I am trying to draw samples from the posterior of the following neural network and use them to estimate the uncertainty of its predictions:

import torch
import torch.nn as nn

class NN(nn.Module):
    def __init__(self, input_size, output_size, weights=None):
        super().__init__()
        self.fc1 = nn.Linear(input_size, output_size)
        if weights is not None:  # I will explain below why I need this
            # Copy the given tensors into the existing parameters instead of replacing .data
            with torch.no_grad():
                self.fc1.weight.copy_(weights['fc1.weight'])
                self.fc1.bias.copy_(weights['fc1.bias'])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        output = self.fc1(x)
        output = self.softmax(output)  # class probabilities, not logits
        return output

To do so, I am doing the following:

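import torch
import pyro
from pyro.distributions import Normal, Categorical

# Assumes image_size = 28 * 28, num_classes = 10 and model_bayes = NN(image_size, num_classes)
def model(x_data, y_data):
    # Independent standard-normal priors; a torch.eye scale would give most weights a standard deviation of 0
    fc1w_prior = Normal(loc=torch.zeros_like(model_bayes.fc1.weight), scale=torch.ones_like(model_bayes.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(model_bayes.fc1.bias), scale=torch.ones_like(model_bayes.fc1.bias))

    # No trailing comma here: it would turn priors into a tuple instead of a dict
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior}
    lifted_module = pyro.random_module("module", model_bayes, priors)
    lifted_reg_model = lifted_module()
    # NN.forward already applies softmax, so its output must be passed as probs, not logits
    lhat = lifted_reg_model(x_data)
    pyro.sample("obs", Categorical(probs=lhat), obs=y_data)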
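Three details in a model like this are easy to get wrong, and each can quietly break the posterior. The pure-Python sketch below (no Pyro or PyTorch involved; `num_classes = 10` and `image_size = 784` are assumed MNIST sizes) checks each one. A trailing comma after a dict literal turns it into a one-element tuple, and a prior that is neither a dict nor a distribution may leave the module parameters un-lifted (worth verifying against your Pyro version). A `torch.eye`-shaped scale gives almost every weight a standard deviation of 0, which is not a valid Normal scale. Passing softmax output as `logits` applies softmax twice, which flattens every prediction.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# 1) A trailing comma after a dict literal builds a one-element tuple, not a dict.
priors = {"fc1.weight": "weight prior", "fc1.bias": "bias prior"},
print(type(priors).__name__)  # tuple

# 2) An eye(10, 784)-shaped scale matrix has ones only on its diagonal,
#    so almost every weight would get a standard deviation of 0.
num_classes, image_size = 10, 28 * 28  # assumed MNIST sizes
eye_scale = [[1.0 if i == j else 0.0 for j in range(image_size)] for i in range(num_classes)]
zero_scales = sum(1 for row in eye_scale for s in row if s == 0.0)
print(zero_scales, "of", num_classes * image_size)  # 7830 of 7840

# 3) Feeding softmax probabilities back in as logits flattens the prediction:
#    the "logits" now lie in [0, 1], so no class can end up more than e times
#    as likely as any other, however confident the network was.
logits = [10.0, 0.0, 0.0]
probs = softmax(logits)
probs_as_logits = softmax(probs)
print([round(p, 3) for p in probs])  # [1.0, 0.0, 0.0]
print([round(p, 3) for p in probs_as_logits])
```

The third check is why a softmax output layer should be paired with `Categorical(probs=...)`, or the softmax dropped from `forward` and the raw scores passed as `logits`.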


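from pyro.infer.mcmc import MCMC, NUTS

num_samples = 500
warmup_steps = 2000

kernel = NUTS(model)
# In older Pyro releases, run() returns the posterior object whose exec_traces are read below
posterior = MCMC(kernel,
                 num_samples=num_samples,
                 warmup_steps=warmup_steps,
                 num_chains=1).run(X_train_b.view(-1, image_size), y_train_b)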

For some reason, NUTS gets through the warmup phase in about a second (even though it is supposed to run 2000 warmup steps) and then finishes the sampling phase quite quickly. The acceptance rate is always exactly 1.
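An acceptance rate of exactly 1 on every iteration can itself be a clue. In plain HMC (NUTS uses a related but more involved acceptance statistic), a proposal is accepted with probability min(1, exp(H_old - H_new)), where H is the potential energy (negative log density) plus the kinetic energy of the momentum. The toy leapfrog sketch below, written in pure Python and not Pyro's implementation, shows that with an empty parameter vector the energy can never change, so every proposal is accepted. One possible explanation here, which is only an assumption, is that the network weights never became latent sample sites at all.

```python
import math


def leapfrog(q, p, grad_potential, step_size, num_steps):
    """Leapfrog integration of H(q, p) = U(q) + |p|^2 / 2 (unit mass matrix)."""
    q, p = list(q), list(p)
    grad = grad_potential(q)
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, grad)]  # half step for momentum
    for step in range(num_steps):
        q = [qi + step_size * pi for qi, pi in zip(q, p)]  # full step for position
        grad = grad_potential(q)
        if step != num_steps - 1:
            p = [pi - step_size * gi for pi, gi in zip(p, grad)]  # full step for momentum
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, grad)]  # final half step
    return q, p


def acceptance_prob(q, p, potential, grad_potential, step_size, num_steps):
    """Metropolis acceptance probability min(1, exp(H_old - H_new)) for one HMC proposal."""
    def hamiltonian(q_, p_):
        return potential(q_) + 0.5 * sum(pi * pi for pi in p_)

    q_new, p_new = leapfrog(q, p, grad_potential, step_size, num_steps)
    delta = hamiltonian(q, p) - hamiltonian(q_new, p_new)
    return 1.0 if delta >= 0 else math.exp(delta)  # avoids overflow when delta is large


def potential(q):
    """Standard normal target: U(q) = |q|^2 / 2 (negative log density up to a constant)."""
    return 0.5 * sum(qi * qi for qi in q)


def grad_potential(q):
    """Gradient of U for the standard normal target."""
    return list(q)


print(acceptance_prob([], [], potential, grad_potential, 0.1, 10))        # 1.0: nothing to sample
print(acceptance_prob([1.0], [0.5], potential, grad_potential, 0.1, 10))  # close to 1
print(acceptance_prob([1.0], [0.5], potential, grad_potential, 2.5, 20))  # ~0: unstable step size
```

A healthy sampler on a real posterior usually shows acceptance rates somewhat below 1 (NUTS adapts its step size toward a target acceptance probability during warmup), so a constant 1 alongside a near-instant warmup is worth investigating.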
The results are poor, though. I draw random points from the posterior with the following code (for each draw I take the weights and biases of one posterior sample and load them into my network):

import numpy as np
import torch.nn as nn

NUMBER_OF_ESTIMATIONS = 500

pred = []
for i in range(NUMBER_OF_ESTIMATIONS):
    # Pick one posterior sample at random and load its weights and bias into a fresh network
    node = np.random.randint(num_samples)
    weights = {}
    weights['fc1.weight'] = nn.Parameter(posterior.exec_traces[node].nodes['module$$$fc1.weight']['value'])
    weights['fc1.bias'] = nn.Parameter(posterior.exec_traces[node].nodes['module$$$fc1.bias']['value'])
    model_with_weights = NN(28 * 28, 10, weights)
    model_with_weights.eval()
    # Accuracy of this single weight sample on the first test batch only (note the break)
    for data in test_loader:
        X_test, y_test = data
        X_test = X_test.view(-1, 28 * 28)
        predictions = model_with_weights(X_test).detach().numpy()
        pred.append((np.argmax(predictions, axis=-1) == y_test.numpy()).astype(float).mean())
        break
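The loop above scores each weight sample on its own. A common alternative for getting predictions with uncertainty is Bayesian model averaging: average the class probabilities over many posterior samples, take the argmax of the average, and use, for example, the entropy of the averaged distribution as an uncertainty score. A minimal pure-Python sketch, assuming the per-sample probabilities have already been collected into nested lists (`posterior_predictive`, `predictive_entropy` and `accuracy` are illustrative helpers, not a Pyro API):

```python
import math


def posterior_predictive(prob_samples):
    """Bayesian model average of class probabilities.

    prob_samples[s][n][k] is the probability of class k for test example n
    under posterior sample s; returns the mean over s for every (n, k).
    """
    num_draws = len(prob_samples)
    return [
        [sum(column) / num_draws for column in zip(*per_example)]
        for per_example in zip(*prob_samples)
    ]


def predictive_entropy(probs):
    """Entropy (in nats) of one predictive distribution; higher means more uncertain."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def accuracy(mean_probs, labels):
    """Fraction of examples whose most probable class matches the label."""
    preds = [max(range(len(p)), key=p.__getitem__) for p in mean_probs]
    return sum(int(pred == y) for pred, y in zip(preds, labels)) / len(labels)


# Two posterior samples, two test examples, two classes (made-up numbers)
samples = [
    [[0.9, 0.1], [0.6, 0.4]],
    [[0.8, 0.2], [0.2, 0.8]],
]
mean_probs = posterior_predictive(samples)
print(accuracy(mean_probs, [0, 1]))  # 1.0
print([round(predictive_entropy(p), 3) for p in mean_probs])  # [0.423, 0.673]
```

The second test example, where the two posterior samples disagree, gets the higher entropy, which is the kind of per-example uncertainty the question is after.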

But the results look like random guessing (the mode of the accuracy is about 0.1, i.e. chance level for 10 classes).

So my question is: what am I doing wrong?
Is there a built-in way to load weights sampled from the posterior into a NN and get predictions in a single line of code?

I have almost no experience with Pyro, or with PPLs in general, so I would appreciate any insights.
Thank you!

Hi @niket096, this looks strange to me; I suspect there is a bug somewhere. Could you provide reproducible code (perhaps with some synthetic data) so I can take a look?

Hello,
Here is a link to my Jupyter notebook on Google Colab: Google Colab
For some reason, on Google Colab the code gets stuck at the NUTS step: even after running overnight it had not completed a single warmup step, so I interrupted it. This happens even though I use the same Pyro version locally as on Colab.
At the end of the notebook I shared, there is a screenshot of a run on my own computer. For some reason the results there are slightly better than random, but note the execution time of NUTS…

I don’t think you should generally expect NUTS to work on neural networks this large (the one-layer model above already has 784 × 10 + 10 = 7850 parameters). Before going to something so large, I suggest you play around with the much smaller network explored here, and then try scaling up to larger networks incrementally.
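To get a feel for the scale, here is a rough back-of-the-envelope sketch in pure Python. NUTS builds each trajectory by repeated doubling, so a fully expanded tree of depth d costs 2**d - 1 leapfrog steps, each one a gradient of the log density over all the data passed to `run()`. Pyro's `NUTS` defaults to `max_tree_depth=10`; the small `[2, 5, 1]` layer sizes are a hypothetical example, not the network linked above.

```python
def mlp_num_parameters(layer_sizes):
    """Number of weights plus biases in a fully connected net with these layer sizes."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


def max_leapfrog_steps(num_iterations, max_tree_depth=10):
    """Worst-case leapfrog steps (gradient evaluations) if every NUTS tree hits max depth.

    Doubling j = 0, 1, ..., d - 1 adds 2**j steps, so a full tree of depth d
    costs 2**d - 1 steps.
    """
    return num_iterations * (2 ** max_tree_depth - 1)


print(mlp_num_parameters([28 * 28, 10]))  # 7850
print(mlp_num_parameters([2, 5, 1]))      # 21
print(max_leapfrog_steps(2000 + 500))     # 2557500
```

With the settings from the question (2000 warmup plus 500 sampling iterations), the worst case is about 2.6 million full-data gradient evaluations on a 7850-dimensional posterior, which is consistent with a run that seems to hang.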
