BNN isn't learning

Hi!

I’m trying to train a Bayesian neural network (BNN) that takes a 1D convolved signal as input and outputs the original signal.

My BNN looks like this:

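import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BNN(PyroModule):

    def __init__(self, h1, h2):
        super().__init__()
        # Layer 1: h1 inputs -> h2 hidden units, Normal priors on weights and biases
        self.fc1 = PyroModule[nn.Linear](h1, h2)
        self.fc1.weight = PyroSample(dist.Normal(0., 0.5).expand([h2, h1]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 1.).expand([h2]).to_event(1))

        #self.fc1 = nn.Linear(h1, h2)

        # Layer 2 receives the h2-dim hidden vector, so it is Linear(h2, h2),
        # matching its [h2, h2] (Cauchy) weight prior
        self.fc2 = PyroModule[nn.Linear](h2, h2)
        self.fc2.weight = PyroSample(dist.Cauchy(0., 0.5).expand([h2, h2]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0., 1.).expand([h2]).to_event(1))

        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        # x: (batch, h1) convolved signals; y: (batch, h2) original signals, or None
        x = self.relu(self.fc1(x))
        mu = self.relu(self.fc2(x))

        # Observation noise scale
        sigma = pyro.sample("sigma", dist.Uniform(0., 0.05))

        # One plate entry per signal (dim 0); the h2 output points form one event.
        # Using x.shape[0] instead of y.shape[1] keeps this working when y is None,
        # e.g. inside Predictive.
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mu, sigma).to_event(1), obs=y)

        return mu

and I’m training it like this:

bnn_model = BNN(h1=N, h2=n)  # here N = 100 (input length), n = 200 (output length)

# Set Pyro random seed
pyro.set_rng_seed(42)

nuts_kernel = pyro.infer.NUTS(bnn_model, jit_compile=True)

# Define MCMC sampler, get 20 posterior samples
bnn_mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=20)

# Convert data to PyTorch tensors
# (y_data: convolved signals = inputs, dataset: original signals = targets)
x_train = torch.from_numpy(y_data).float()
y_train = torch.from_numpy(dataset).float()
# Run MCMC
bnn_mcmc.run(x_train, y_train)
# Get predictions; x_train[:1] keeps the batch dimension, shape (1, 100)
predictive = pyro.infer.Predictive(model=bnn_model, posterior_samples=bnn_mcmc.get_samples())
preds = predictive(x_train[:1])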

The shapes of x_train and y_train are (400, 100) and (400, 200). They contain 400 artificially generated step-function signals.
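The thread doesn’t show how the signals were generated. As a hypothetical stand-in with the same shapes, the sketch below builds 400 random step functions of length 200 (targets) and a box-filtered, 2x-downsampled copy of each, of length 100 (inputs). The kernel, the downsampling and the helper name make_step_signals are assumptions, not the poster’s actual pipeline.

```python
import numpy as np


def make_step_signals(num_signals=400, length=200, kernel_size=5, seed=0):
    """Hypothetical stand-in data: random step functions (targets) and a
    blurred, 2x-downsampled version of each (inputs)."""
    rng = np.random.default_rng(seed)
    t = np.arange(length)
    originals = np.zeros((num_signals, length), dtype=np.float32)
    for i in range(num_signals):
        # Each signal gets 1 to 3 upward steps at random positions
        for pos in rng.integers(0, length, size=rng.integers(1, 4)):
            originals[i] += (t >= pos) * rng.uniform(0.2, 1.0)
    # Assumed forward model: box-filter convolution, then keep every 2nd sample
    kernel = np.ones(kernel_size, dtype=np.float32) / kernel_size
    blurred = np.stack([np.convolve(s, kernel, mode="same") for s in originals])
    convolved = np.ascontiguousarray(blurred[:, ::2])
    return convolved, originals


# Same variable names as in the question: y_data = inputs, dataset = targets
y_data, dataset = make_step_signals()
print(y_data.shape, dataset.shape)  # (400, 100) (400, 200)
```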

My problem is that the model doesn’t seem to learn anything. I also tried the same task with a standard neural network, just to see whether it is possible at all, and got some OK predictions. The BNN, however, just outputs constant noise, as can be seen in the attached plot. (I took the mean over the samples stored in the preds variable.)

Do you have any idea what the problem could be? The longest I’ve trained it is with 500 samples (~1.5 h), so I don’t think the problem is a lack of training.


MCMC will only give you samples of sigma. You will want to apply a Poutine effect handler to your BNN parameters (see "Poutine (Effect handlers)" in the Pyro documentation) if you want to obtain their posterior samples. It might be better to use SVI here, I guess.

Hi!
Thanks for the help! I switched to SVI, used poutine, and now the BNN works well.