Dealing with noise in Bayesian neural network regression

I am trying to perform Bayesian neural network regression on a 1-D synthetic dataset. When defining the model, I first used the known, fixed noise level and everything went fine. Then I placed a Gamma prior on the precision to see if I could get a good posterior estimate of the noise, and I observed that

  1. The Trace_ELBO loss value is much better (lower) than when using the fixed, known precision

  2. The prediction and the quality of the uncertainty estimates get much worse

So I want to ask:

  1. Why does a better ELBO give a much worse prediction?

  2. What is a proper way to deal with unknown precision?

Below is my code for the Bayesian neural network with fixed, known precision:

import torch
import torch.nn as nn
import pyro
import pyro.optim
import pyro.infer
import numpy as np
from   torch.distributions import constraints
    
noise_level = 0.01
num_data    = 1000
xs          = torch.linspace(-3, 3, num_data).reshape(num_data, 1)
ys          = torch.FloatTensor(np.sinc(xs.numpy()))
ys          = ys + noise_level * torch.randn_like(ys)

class BNN_SVI:
    def __init__(self, dim, conf = dict()):
        self.dim           = dim
        self.num_iters     = conf.get('num_iters',    400)
        self.print_every   = conf.get('print_every',  100)
        self.batch_size    = conf.get('batch_size',   128)
        self.lr            = conf.get('lr',           1e-3)
        self.weight_prior  = conf.get('weight_prior', 1.0)
        self.bias_prior    = conf.get('bias_prior',   1.0)
        self.prec_alpha    = conf.get('prec_alpha', 3)
        self.prec_beta     = conf.get('prec_alpha', 1)
        self.nn = nn.Sequential(
                nn.Linear(self.dim, 50), nn.Tanh(), 
                nn.Linear(50, 1))

    def model(self, X, y):
        noise_scale = torch.tensor(noise_level)
        # precision   = pyro.sample("precision", pyro.distributions.Gamma(self.prec_alpha, self.prec_beta))
        # noise_scale = 1 / precision.sqrt()
        num_x       = X.shape[0]
        priors      = dict()
        for n, p in self.nn.named_parameters():
            if "weight" in n:
                priors[n] = pyro.distributions.Normal(
                        loc   = torch.zeros_like(p),
                        scale = torch.ones_like(p)).to_event(1)
            elif "bias" in n:
                priors[n] = pyro.distributions.Normal(
                        loc   = torch.zeros_like(p),
                        scale = torch.ones_like(p)).to_event(1)

        lifted_module    = pyro.random_module("module", self.nn, priors)
        lifted_reg_model = lifted_module()
        with pyro.plate("map", len(X), subsample_size = min(num_x, self.batch_size)) as ind:
            prediction_mean = lifted_reg_model(X[ind]).squeeze(-1)
            pyro.sample("obs", 
                    pyro.distributions.Normal(prediction_mean, noise_scale), 
                    obs = y[ind])

    def guide(self, X, y):
        softplus  = nn.Softplus()
        # alpha     = pyro.param("alpha", torch.tensor(self.prec_alpha), constraint = constraints.positive)
        # beta      = pyro.param("beta",  torch.tensor(self.prec_beta),  constraint = constraints.positive)
        # precision = pyro.sample("precision", pyro.distributions.Gamma(alpha, beta))
        priors      = dict()
        for n, p in self.nn.named_parameters():
            if "weight" in n:
                loc   = pyro.param("mu_"    + n, self.weight_prior * torch.randn_like(p))
                scale = pyro.param("sigma_" + n, softplus(torch.randn_like(p)), constraint = constraints.positive)
                priors[n] = pyro.distributions.Normal(loc = loc, scale = scale).to_event(1)
            elif "bias" in n:
                loc       = pyro.param("mu_"    + n, self.bias_prior * torch.randn_like(p))
                scale     = pyro.param("sigma_" + n, softplus(torch.randn_like(p)), constraint = constraints.positive)
                priors[n] = pyro.distributions.Normal(loc = loc, scale = scale).to_event(1)
        lifted_module = pyro.random_module("module", self.nn, priors)
        return lifted_module()
            
    def train(self, X, y):
        num_train   = X.shape[0]
        y           = y.reshape(num_train)
        self.x_mean = X.mean(dim = 0)
        self.x_std  = X.std(dim = 0)
        self.y_mean = y.mean()
        self.y_std  = y.std()
        self.X      = (X - self.x_mean) / self.x_std
        self.y      = (y - self.y_mean) / self.y_std
        optim       = pyro.optim.Adam({"lr":self.lr})
        svi         = pyro.infer.SVI(self.model, self.guide, optim, loss = pyro.infer.Trace_ELBO())
        pyro.clear_param_store()
        self.rec = []
        for i in range(self.num_iters):
            loss = svi.step(self.X, self.y)
            self.rec.append(loss / num_train)
            if (i+1) % self.print_every == 0:
                print("[Iteration %05d] loss: %.4f" % (i + 1, loss / num_train))
    
    def sample(self):
        net = self.guide(self.X, self.y)
        return net

    def sample_predict(self, nn, x):
        return nn((x - self.x_mean) / self.x_std) * self.y_std + self.y_mean



conf                 = dict()
conf['num_iters']    = 5000
conf['batch_size']   = 32
conf['print_every']  = 50
conf['weight_prior'] = 1.
conf['bias_prior']   = 1.
conf['lr']           = 1e-1
conf['prec_alpha']   = 10. # precision mean = 20 variance = 40
conf['prec_beta']    = 0.5
model                = BNN_SVI(dim = 1, conf = conf)

num_train = 20
train_id  = torch.randperm(num_data)[:num_train]
train_x   = xs[train_id]
train_y   = ys[train_id]
model.train(train_x, train_y)
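
For reference, the predictive curves in the figure below are generated roughly like this (a minimal sketch using the sample()/sample_predict() helpers above; num_samples is arbitrary and the plotting code is omitted):

# Sketch: draw posterior predictive samples over the full input grid
num_samples = 100
with torch.no_grad():
    preds = torch.stack([model.sample_predict(model.sample(), xs).squeeze(-1)
                         for _ in range(num_samples)])
pred_mean = preds.mean(dim = 0)   # predictive mean over the grid
pred_std  = preds.std(dim = 0)    # predictive uncertainty over the grid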

The prediction and the loss curve are plotted in the figure below. As can be seen, the prediction is pretty reasonable, while the final ELBO loss values are at the level of 1e2.

However, if we replace noise_scale = torch.tensor(noise_level) in the model function with

precision   = pyro.sample("precision", pyro.distributions.Gamma(self.prec_alpha, self.prec_beta))
noise_scale = 1 / precision.sqrt()

and add the corresponding parameter definitions in the guide function:

alpha     = pyro.param("alpha", torch.tensor(self.prec_alpha), constraint = constraints.positive)
beta      = pyro.param("beta",  torch.tensor(self.prec_beta),  constraint = constraints.positive)
precision = pyro.sample("precision", pyro.distributions.Gamma(alpha, beta))

The prediction becomes this:

We can see that the ELBO loss goes below 10 very quickly.

Maybe it should be `conf.get('prec_beta', 1)`?
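
That is, the line in __init__ would become:

self.prec_beta     = conf.get('prec_beta', 1)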

Thanks, but even with that bug fixed, the Gamma prior on the precision still makes the prediction worse (and the ELBO loss better), as shown in the figure below.

Hi @Alaya, I think that with the same mean prediction, a smaller noise scale gives a smaller likelihood for the observations, hence a larger ELBO loss. You put the prior mean of the precision around 20, so the noise starts around 0.2 (> 0.01), hence you will start with a smaller ELBO loss than the first model. As training runs, SVI will find that larger noise is better (because it gives a smaller ELBO loss), so it will tend to increase the noise.
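
For example, for a single point with a residual of 0.5 (just an illustrative number, not taken from your actual fit), the per-observation log-likelihood under different noise scales looks like this:

import torch
from torch.distributions import Normal

resid = torch.tensor(0.5)                    # illustrative residual, not from the actual fit
print(Normal(0., 0.01).log_prob(resid))      # ~ -1246.3 : tiny fixed noise, huge penalty
print(Normal(0., 0.22).log_prob(resid))      # ~ -2.0    : noise ~ 1/sqrt(20), the prior mean precision
print(Normal(0., 1.00).log_prob(resid))      # ~ -1.0    : large noise, small penalty

So while the mean prediction is still bad, increasing the noise scale is the cheapest way for SVI to improve the ELBO loss.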

I think that is the reason why your model does not give a good prediction. If you control your precision (e.g. with a lower bound), then as SVI runs it will decrease the precision until it reaches some balanced value. After that, the other parameters will start to learn (provided the KL term for the precision does not dominate the ELBO).
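
For instance, one crude way to do this inside model() is to clamp the noise scale; the bound of 0.5 here is just a placeholder, not a tuned value:

precision   = pyro.sample("precision", pyro.distributions.Gamma(self.prec_alpha, self.prec_beta))
noise_scale = (1 / precision.sqrt()).clamp(max = 0.5)   # upper bound on noise <=> lower bound on precision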
