Bayesian LeNet-5 gives poor results — help!


#1

I am trying to implement a Bayesian LeNet-5 for MNIST classification in Pyro. I get good results with a simpler network, but with LeNet-5 the performance is extremely bad no matter how I change the KL annealing: accuracy stays at 10%, i.e. chance level. By contrast, when I implement the Bayesian LeNet for MNIST in TensorFlow with KL annealing, I get much better results. I have been trying to find the problem with my network for days; could someone please help? The code is as follows:

class LeNet(nn.Module):
    """LeNet-5-style convolutional network with softplus activations.

    Two 5x5 convolutions, each followed by 2x2 max-pooling, then three fully
    connected layers. The first dense layer expects 256 = 16 * 4 * 4 features.
    That is what a 28x28 input (e.g. MNIST) yields after
    conv(5) -> pool(2) -> conv(5) -> pool(2).

    Args:
        num_classes: number of output logits produced by ``fc3``.
        inputs: number of input image channels (1 for greyscale).
    """

    def __init__(self, num_classes, inputs=1):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(inputs, 6, 5, stride=1)
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map an image batch to unnormalised class logits.

        Args:
            x: tensor of shape (batch, inputs, 28, 28).

        Returns:
            Tensor of shape (batch, num_classes); no softmax is applied here.
        """
        out = F.softplus(self.conv1(x))
        out = self.pool1(out)
        out = F.softplus(self.conv2(out))
        out = self.pool2(out)
        # Flatten (batch, 16, 4, 4) -> (batch, 256) for the dense layers.
        out = out.view(out.size(0), -1)
        out = F.softplus(self.fc1(out))
        out = F.softplus(self.fc2(out))
        out = self.fc3(out)
        return out

def conv_normal_prior(name, params, mu_init_scale=0.1, sigma_init=-5.0):
    """Create a learnable mean-field Normal distribution shaped like ``params``.

    Registers two Pyro params, ``'<name>_mu'`` and ``'<name>_sigma'``. The
    sigma param is unconstrained and passed through softplus to obtain a
    positive scale.

    Initialisation matters for a deep Bayesian net. With mu ~ N(0, 1) and a raw
    sigma ~ N(0, 1), the scale starts around softplus(0) ~= 0.69. Every sampled
    weight is then dominated by noise, a likely reason a network as deep as
    LeNet fails to train. The defaults below start the means small and the
    scales tiny (softplus(-5) ~= 0.0067), so the initial sampled network is
    close to a normally initialised deterministic one.

    Args:
        name: prefix for the two Pyro param names.
        params: tensor whose shape, dtype and device the distribution copies.
        mu_init_scale: standard deviation of the random initial means.
        sigma_init: initial value of the raw (pre-softplus) scale parameter.

    Returns:
        ``Normal(loc=mu, scale=softplus(raw_sigma))`` with the same shape as ``params``.
    """
    mu_init = torch.randn_like(params) * mu_init_scale
    raw_sigma_init = torch.full_like(params, sigma_init)
    # pyro.param only uses the initial tensor the first time a name is seen;
    # afterwards the stored (optimised) value is returned.
    mu_param = pyro.param('{}_mu'.format(name), mu_init)
    sigma_param = F.softplus(pyro.param('{}_sigma'.format(name), raw_sigma_init))
    return Normal(loc=mu_param, scale=sigma_param)

def dense_normal_prior(name, params, mu_init_scale=0.1, sigma_init=-5.0):
    """Create a learnable mean-field Normal distribution for a dense-layer tensor.

    Registers two Pyro params, ``'<name>_mu'`` and ``'<name>_sigma'``. The
    sigma param is unconstrained and passed through softplus to obtain a
    positive scale.

    Starting from mu ~ N(0, 1) and raw sigma ~ N(0, 1) (scale ~= 0.69) makes
    the sampled fully connected weights far too large and noisy. The defaults
    instead use small random means and a tiny initial scale
    (softplus(-5) ~= 0.0067).

    Args:
        name: prefix for the two Pyro param names.
        params: tensor whose shape, dtype and device the distribution copies.
        mu_init_scale: standard deviation of the random initial means.
        sigma_init: initial value of the raw (pre-softplus) scale parameter.

    Returns:
        ``Normal(loc=mu, scale=softplus(raw_sigma))`` with the same shape as ``params``.
    """
    mu_init = torch.randn_like(params) * mu_init_scale
    raw_sigma_init = torch.full_like(params, sigma_init)
    # The initial tensors are only used the first time each param name is registered.
    mu_param = pyro.param('{}_mu'.format(name), mu_init)
    sigma_param = F.softplus(pyro.param("{}_sigma".format(name), raw_sigma_init))
    return Normal(loc=mu_param, scale=sigma_param)

def model(x_data, y_data):
    """Pyro model: a Bayesian version of the global ``net`` with N(0, 1) priors.

    Every parameter of ``net`` gets an independent standard-Normal prior of the
    same shape. The module is lifted with ``pyro.random_module`` under the name
    ``"module"`` and one network is sampled. The labels are then observed at
    site ``"obs"`` through a Categorical likelihood on its outputs.

    Args:
        x_data: image batch fed to the sampled network (the training loop
            passes shape (batch, 1, 28, 28)).
        y_data: integer class labels observed at the ``"obs"`` site.
    """
    # One standard-Normal prior per parameter. The keys come from
    # net.named_parameters(), e.g. 'conv1.weight', 'conv1.bias', ..., 'fc3.bias'.
    priors = {
        param_name: Normal(loc=torch.zeros_like(param), scale=torch.ones_like(param))
        for param_name, param in net.named_parameters()
    }

    # Lift module parameters to random variables sampled from the priors,
    # then draw one concrete classifier.
    lifted_module = pyro.random_module("module", net, priors)
    lifted_reg_model = lifted_module()

    # Explicit dim: the implicit-dim form of log_softmax is deprecated. For
    # the (batch, num_classes) logits here it resolved to dim=1 anyway.
    p_hat = log_softmax(lifted_reg_model(x_data), dim=1)

    # NOTE(review): this minibatch log-likelihood is not rescaled to the full
    # dataset size. The KL term, by contrast, covers every weight, so it is
    # relatively over-weighted by roughly len(dataset) / batch_size. Consider
    # pyro.plate(..., size=N, subsample_size=B) to apply that scaling.
    pyro.sample("obs", Categorical(logits=p_hat), obs=y_data)

def guide(x_data, y_data):
    """Pyro guide: mean-field Normal posterior over every weight and bias of ``net``.

    Each layer tensor gets its own learnable Normal:
    - conv layers: built by ``conv_normal_prior``;
    - fully connected layers: built by ``dense_normal_prior``.

    The resulting dict is lifted with ``pyro.random_module`` under the same
    ``"module"`` name the model uses, so the sampled site names line up when
    the model is replayed against this guide.

    Args:
        x_data, y_data: unused; accepted so the signature matches ``model``.

    Returns:
        One network sampled from the variational distribution.
    """
    # (random_module key, pyro.param prefix, tensor, factory).
    # Row order fixes the order in which the Pyro params are created (and in
    # which any random initial values are drawn), so keep it stable.
    site_specs = (
        ('conv1.weight', 'conv1w', net.conv1.weight, conv_normal_prior),
        ('conv1.bias', 'conv1b', net.conv1.bias, conv_normal_prior),
        ('conv2.weight', 'conv2w', net.conv2.weight, conv_normal_prior),
        ('conv2.bias', 'conv2b', net.conv2.bias, conv_normal_prior),
        ('fc1.weight', 'fc1w', net.fc1.weight, dense_normal_prior),
        ('fc1.bias', 'fc1b', net.fc1.bias, dense_normal_prior),
        ('fc2.weight', 'fc2w', net.fc2.weight, dense_normal_prior),
        ('fc2.bias', 'fc2b', net.fc2.bias, dense_normal_prior),
        ('fc3.weight', 'fc3w', net.fc3.weight, dense_normal_prior),
        ('fc3.bias', 'fc3b', net.fc3.bias, dense_normal_prior),
    )
    variational_dists = {
        site_key: make_dist(param_prefix, tensor)
        for site_key, param_prefix, tensor, make_dist in site_specs
    }
    return pyro.random_module("module", net, variational_dists)()

def simple_elbo_kl_annealing(model, guide, *args, **kwargs):
    """Negative single-sample ELBO with selected latent sites down-weighted.

    Intended as the ``loss=`` callable of ``SVI``. Two keyword arguments are
    consumed here and not forwarded:

    - ``annealing_factor`` (default 1.0): multiplier for matching sites.
    - ``latents_to_anneal`` (default []): site-name prefixes to scale.

    Any remaining positional and keyword arguments are passed to both ``model``
    and ``guide``.

    A site matches when the part of its name before ``'$$$'`` is listed in
    ``latents_to_anneal``. This is presumably the separator ``pyro.random_module``
    uses in names like ``'module$$$fc1.weight'`` — confirm against your Pyro version.
    Matching sites are scaled in both the model (log-prior) and the guide
    (log-density), i.e. their KL contribution. Unmatched sites, such as the
    observed ``"obs"`` site, keep weight 1.0.

    Returns:
        ``-(sum of weighted model log-probs - sum of weighted guide log-probs)``.
    """
    annealing_factor = kwargs.pop('annealing_factor', 1.0)
    latents_to_anneal = kwargs.pop('latents_to_anneal', [])

    # Sample latents from the guide, then score the model at those same values.
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    replayed_model = poutine.replay(model, trace=guide_trace)
    model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)

    def site_weight(site_name):
        prefix = site_name.split('$$$')[0]
        return annealing_factor if prefix in latents_to_anneal else 1.0

    elbo = 0.0
    # Model terms are added and guide terms subtracted, model trace first.
    for trace, sign in ((model_trace, 1.0), (guide_trace, -1.0)):
        for site in trace.nodes.values():
            if site["type"] != "sample":
                continue
            weighted = site_weight(site["name"]) * site["fn"].log_prob(site["value"]).sum()
            elbo = elbo + sign * weighted
    return -elbo

import math

# Fresh parameter store, then SVI driven by the custom annealed objective.
pyro.clear_param_store()
optim = Adam({"lr": 0.005})
svi = SVI(model, guide, optim, loss=simple_elbo_kl_annealing)

num_iterations = 100
# Only sites whose name prefix is "module" (the lifted network weights) are annealed.
latents_to_anneal = ['module']
loss = 0
normalizer_train = len(train_loader.dataset)

for j in range(num_iterations):
    # KL weight for this epoch is exp(j + 1 - num_iterations). It is ~e^-99 at
    # the first epoch, still below 1e-4 until about epoch 90, and reaches 1.0
    # only on the final epoch. A linear alternative would be
    # annealing_init + j * (1 - annealing_init) / num_iterations,
    # with annealing_init = 1/60000.
    annealing_factor = math.exp(-num_iterations + j + 1)

    loss = 0
    for batch_id, data in enumerate(train_loader):
        images = data[0].view(-1, 1, 28, 28).cuda()
        labels = data[1].cuda()
        # One gradient step; svi.step returns that minibatch's loss.
        loss += svi.step(
            images,
            labels,
            annealing_factor=annealing_factor,
            latents_to_anneal=latents_to_anneal,
        )

    total_epoch_loss_train = loss / normalizer_train
    acc = evaluate(10, test_loader)
    print("Epoch ", j, " Loss ", total_epoch_loss_train, 'Accuracy ', acc)