Help with a simple classification BNN


#1

Hi all,
I made a very simple dataset with two inputs, z and x, where the output is either class 1 or class 0.

To keep the problem easy, x is always equal to 1. z is sampled from a standard normal distribution N(0, 1); if z < 0 the output is class 1, and if z > 0 the output is class 0.

The model and guide I used are below (adapted from an IBM example):

class MLP(nn.Module):
    """One-hidden-layer perceptron: Linear -> ReLU -> Linear.

    The attribute names ``l1`` and ``l2`` are referenced by the prior and
    guide dictionaries (``'l1.weight'``, ``'l2.bias'``, ...), so they must
    not be renamed.

    Args:
        n_in: number of input features; defaults to the module-level ``nx``.
        n_hidden: number of hidden units; defaults to the module-level ``nh``.
        n_out: number of outputs; defaults to the module-level ``ny``.
    """

    def __init__(self, n_in=None, n_hidden=None, n_out=None):
        super().__init__()
        # Fall back to the module-level sizes so a bare ``MLP()`` keeps working.
        n_in = nx if n_in is None else n_in
        n_hidden = nh if n_hidden is None else n_hidden
        n_out = ny if n_out is None else n_out
        self.l1 = torch.nn.Linear(n_in, n_hidden)
        self.l2 = torch.nn.Linear(n_hidden, n_out)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return raw (unnormalized) output scores of shape (batch, n_out).

        ``x`` may be a tensor, an array or a nested list. It is converted to
        the dtype/device of the first layer's weights and reshaped to
        ``(-1, in_features)``.
        """
        # torch.Tensor(x) always builds a CPU float tensor, which fails when the
        # network lives on a CUDA device. as_tensor matches the weights instead
        # and avoids a copy when x already has the right dtype/device.
        weight = self.l1.weight
        x = torch.as_tensor(x, dtype=weight.dtype, device=weight.device)
        # reshape (unlike view) also accepts non-contiguous inputs.
        hidden = self.relu(self.l1(x.reshape(-1, self.l1.in_features)))
        return self.l2(hidden)

# Network shared by model() and guide(); both lift its parameters to random
# variables with pyro.random_module("mlp", mlp, ...).
mlp = MLP()
mlp.to(device)  # nn.Module.to moves the parameters in place and returns self


# Model

def normal(*shape):
    """Return a standard-normal distribution N(0, 1) of the given shape on ``device``."""
    mean = torch.zeros(*shape, device=device)
    std = torch.ones(*shape, device=device)
    return Normal(mean, std)

def model(imgs, lbls, dataset_size=None):
    """Bayesian MLP classifier with N(0, 1) priors on every weight and bias.

    Args:
        imgs: minibatch of inputs fed to the network.
        lbls: observed class labels for the minibatch.
        dataset_size: total number of training examples. When training on
            minibatches, pass it so the minibatch log-likelihood is scaled by
            ``dataset_size / batch_size``. Without that scaling, the prior
            (KL) part of the ELBO is over-weighted relative to the data.
            Defaults to the batch size, i.e. no rescaling.
    """
    priors = {
        'l1.weight': normal(nh, nx), 'l1.bias': normal(nh),
        'l2.weight': normal(ny, nh), 'l2.bias': normal(ny)}
    lifted_module = pyro.random_module("mlp", mlp, priors)
    lifted_reg_model = lifted_module()
    # Categorical(logits=...) normalizes the logits itself, so an explicit
    # log_softmax is redundant (and calling it without `dim` is deprecated).
    logits = lifted_reg_model(imgs)
    batch_size = logits.shape[0]
    size = batch_size if dataset_size is None else dataset_size
    # Mark the batch dimension as conditionally independent. When size >
    # batch_size, Pyro scales this site's log-likelihood by size / batch_size.
    with pyro.plate("data", size=size, subsample_size=batch_size):
        pyro.sample("obs", Categorical(logits=logits), obs=lbls)


# Inference Guide

def vnormal(name, *shape, scale_init=-3.0):
    """Create a learnable Normal variational factor of the given shape.

    Registers two Pyro params, ``name + "m"`` (mean) and ``name + "s"``
    (unconstrained scale), and returns ``Normal(mean, softplus(scale))``.

    Args:
        name: prefix for the two param-store entries.
        *shape: shape of the factor.
        scale_init: initial value of the unconstrained scale. The default
            -3.0 gives softplus(-3) ~= 0.05, so training starts with a
            nearly deterministic network. Drawing it from ``torch.randn``
            instead gives initial standard deviations around 0.7 (about 30%
            of them above 1), which makes the sampled weights very noisy.
    """
    # pyro.param makes the initial tensors trainable, so requires_grad is not
    # needed. The initial values are only used the first time a name is
    # registered in the param store.
    loc = pyro.param(name + "m", torch.randn(*shape, device=device))
    scale = pyro.param(
        name + "s", torch.full(shape, float(scale_init), device=device))
    return Normal(loc, softplus(scale))

def guide(imgs, lbls, dataset_size=None):
    """Mean-field Normal variational posterior over the MLP's parameters.

    Each weight/bias gets an independent learnable Normal (see ``vnormal``).
    The network is lifted under the same name, ``"mlp"``, as in the model.
    Returns the network sampled by ``lifted_module()``.

    Args:
        imgs, lbls: unused here. SVI calls the guide with the same arguments
            as the model.
        dataset_size: unused; accepted for the same reason, so that a model
            taking a ``dataset_size`` argument can be trained with this guide.
    """
    dists = {
        'l1.weight': vnormal("W1", nh, nx), 'l1.bias': vnormal("b1", nh),
        'l2.weight': vnormal("W2", ny, nh), 'l2.bias': vnormal("b2", ny)}
    lifted_module = pyro.random_module("mlp", mlp, dists)
    return lifted_module()

# Stochastic variational inference: fits the params registered in `guide` by
# maximizing the ELBO (Trace_ELBO) of `model`, using Adam with lr = 0.001.
inference = SVI(
    model,
    guide,
    Adam({"lr": 0.001}),
    loss=Trace_ELBO(),
)

When testing this on minibatches of 20, performance seems to get worse as the number of epochs increases.

Could anybody explain why this happens?