I am trying to implement a Bayesian LeNet-5 for MNIST classification. I get good results with a simpler network, but with this one the performance is extremely bad no matter how I change the KL annealing schedule. However, I get better results when I use TensorFlow to implement the Bayesian LeNet for MNIST with KL annealing. Could someone please help me? I have been trying to find the problem with my network for days. The accuracy of this network is 10%. The code is as follows:
class LeNet(nn.Module):
    """LeNet-5 style convolutional classifier with softplus activations.

    Two 5x5 convolutions (6 and 16 output channels), each followed by a
    softplus and a 2x2 max-pool, feed three fully connected layers
    (120 -> 84 -> ``num_classes``).  ``forward`` returns the raw output of
    the last linear layer; no softmax is applied here.

    Args:
        num_classes: Width of the final linear layer.
        inputs: Number of input channels.
        input_size: Spatial size of the input images, either a single int
            for square images or a ``(height, width)`` pair.  The default
            of 28 gives the original 256-wide input to ``fc1``.

    Raises:
        ValueError: If ``input_size`` is too small to survive both
            convolution/pooling stages.
    """

    def __init__(self, num_classes, inputs=1, input_size=28):
        super(LeNet, self).__init__()
        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        # Validate the geometry before any layer is built.
        flat_features = (
            16 * self._pooled_extent(height) * self._pooled_extent(width)
        )
        # Attribute names and construction order are kept stable: code that
        # addresses parameters by name ('conv1.weight', ...) depends on them.
        self.conv1 = nn.Conv2d(inputs, 6, 5, stride=1)
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(flat_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    @staticmethod
    def _pooled_extent(extent):
        """Return one spatial extent after both conv(5) + pool(2) stages."""
        for _ in range(2):
            # An unpadded 5x5 convolution removes 4 pixels; the 2x2 pool
            # with stride 2 then halves the result, rounding down.
            extent = (extent - 4) // 2
            if extent < 1:
                raise ValueError(
                    "input_size is too small for two 5x5 conv + 2x2 pool stages"
                )
        return extent

    def forward(self, x):
        """Map a batch of images ``(batch, inputs, H, W)`` to class scores."""
        out = F.softplus(self.conv1(x))
        out = self.pool1(out)
        out = F.softplus(self.conv2(out))
        out = self.pool2(out)
        # Flatten everything except the batch dimension for the dense layers.
        out = out.view(out.size(0), -1)
        out = F.softplus(self.fc1(out))
        out = F.softplus(self.fc2(out))
        out = self.fc3(out)
        return out
def conv_normal_prior(name, params):
    """Build a learnable Normal distribution shaped like ``params``.

    Two Pyro parameters are requested, ``'<name>_mu'`` and
    ``'<name>_sigma'``; each is given a fresh standard-normal tensor with
    the shape of ``params`` as its initial value.  The sigma parameter is
    passed through softplus so that the resulting scale is positive.

    Args:
        name: Prefix for the two Pyro parameter names.
        params: Tensor whose shape (and dtype/device, via ``randn_like``)
            the distribution parameters copy.

    Returns:
        ``Normal(loc=mu, scale=softplus(sigma))``.
    """
    loc_key = '{}_mu'.format(name)
    loc = pyro.param(loc_key, torch.randn_like(params))

    scale_key = '{}_sigma'.format(name)
    raw_scale = pyro.param(scale_key, torch.randn_like(params))

    return Normal(loc=loc, scale=F.softplus(raw_scale))
def dense_normal_prior(name, params):
    """Return a Normal whose loc and scale are Pyro parameters.

    The location is the Pyro parameter ``'<name>_mu'``; the scale is
    softplus applied to the Pyro parameter ``'<name>_sigma'``.  Both
    parameters are offered a ``torch.randn_like(params)`` tensor as their
    initial value.

    Args:
        name: Prefix used to form the two parameter names.
        params: Reference tensor that fixes the shape of the distribution.

    Returns:
        A ``Normal`` distribution with the same shape as ``params``.
    """
    mean = pyro.param('{}_mu'.format(name), torch.randn_like(params))
    unconstrained_std = pyro.param(
        "{}_sigma".format(name), torch.randn_like(params)
    )
    # softplus keeps the standard deviation strictly positive.
    std = F.softplus(unconstrained_std)
    return Normal(loc=mean, scale=std)
def model(x_data, y_data):
    """Generative model: standard-normal priors over every ``net`` parameter.

    Each weight and bias tensor of the module-level ``net`` gets a
    ``Normal(0, 1)`` prior of matching shape.  The module is lifted with
    ``pyro.random_module`` under the name ``"module"``, one network is
    sampled, and ``y_data`` is observed under a Categorical likelihood
    whose logits are the log-softmax of that network's output on
    ``x_data``.

    Args:
        x_data: Batch of inputs passed to the sampled network.
        y_data: Observed labels for the ``"obs"`` sample site.
    """
    prior_targets = (
        ('conv1.weight', net.conv1.weight),
        ('conv1.bias', net.conv1.bias),
        ('conv2.weight', net.conv2.weight),
        ('conv2.bias', net.conv2.bias),
        ('fc1.weight', net.fc1.weight),
        ('fc1.bias', net.fc1.bias),
        ('fc2.weight', net.fc2.weight),
        ('fc2.bias', net.fc2.bias),
        ('fc3.weight', net.fc3.weight),
        ('fc3.bias', net.fc3.bias),
    )
    priors = {}
    for param_name, tensor in prior_targets:
        priors[param_name] = Normal(
            loc=torch.zeros_like(tensor), scale=torch.ones_like(tensor)
        )

    # Lift the module so its parameters become random variables drawn from
    # the priors, then draw one concrete classifier.
    lifted_module = pyro.random_module("module", net, priors)
    sampled_classifier = lifted_module()

    log_probs = log_softmax(sampled_classifier(x_data))
    pyro.sample("obs", Categorical(logits=log_probs), obs=y_data)
def guide(x_data, y_data):
    """Variational guide: one learnable Normal per ``net`` parameter.

    For every weight and bias of the module-level ``net`` a Normal with
    learnable location and scale is built (``conv_normal_prior`` for the
    convolutional layers, ``dense_normal_prior`` for the fully connected
    ones).  The module is then lifted under the same name ``"module"``
    that the model uses, and one sampled network is returned.

    Args:
        x_data: Unused here; accepted so the signature matches ``model``.
        y_data: Unused here; accepted so the signature matches ``model``.

    Returns:
        A network sampled from the variational distributions.
    """
    # (site key, Pyro parameter prefix, reference tensor, factory).  The
    # row order fixes the order in which the Pyro parameters are created.
    site_specs = (
        ('conv1.weight', 'conv1w', net.conv1.weight, conv_normal_prior),
        ('conv1.bias', 'conv1b', net.conv1.bias, conv_normal_prior),
        ('conv2.weight', 'conv2w', net.conv2.weight, conv_normal_prior),
        ('conv2.bias', 'conv2b', net.conv2.bias, conv_normal_prior),
        ('fc1.weight', 'fc1w', net.fc1.weight, dense_normal_prior),
        ('fc1.bias', 'fc1b', net.fc1.bias, dense_normal_prior),
        ('fc2.weight', 'fc2w', net.fc2.weight, dense_normal_prior),
        ('fc2.bias', 'fc2b', net.fc2.bias, dense_normal_prior),
        ('fc3.weight', 'fc3w', net.fc3.weight, dense_normal_prior),
        ('fc3.bias', 'fc3b', net.fc3.bias, dense_normal_prior),
    )
    priors = {}
    for site_key, param_prefix, tensor, make_distribution in site_specs:
        priors[site_key] = make_distribution(param_prefix, tensor)

    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
def simple_elbo_kl_annealing(model, guide, *args, **kwargs):
    """Single-sample negative ELBO with an annealing weight on chosen latents.

    The guide is traced, the model is replayed against that trace, and the
    summed log-probabilities of all model sample sites minus those of all
    guide sample sites form the ELBO estimate.  A site's log-probability
    is multiplied by ``annealing_factor`` when the part of its name before
    the first ``'$$$'`` is listed in ``latents_to_anneal``; every other
    site has weight 1.0.

    Args:
        model: Model callable.
        guide: Guide callable.
        *args: Positional arguments forwarded to both model and guide.
        **kwargs: Keyword arguments forwarded to both, after removing
            ``annealing_factor`` (default 1.0) and ``latents_to_anneal``
            (default ``[]``).

    Returns:
        The negated ELBO estimate, to be minimised.
    """
    # These two entries configure the loss and must not reach model/guide.
    annealing_factor = kwargs.pop('annealing_factor', 1.0)
    latents_to_anneal = kwargs.pop('latents_to_anneal', [])

    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    replayed_model = poutine.replay(model, trace=guide_trace)
    model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)

    def weighted_log_prob(site):
        # Sites are matched on the name prefix before '$$$' (presumably the
        # module name given to pyro.random_module -- confirm).
        prefix = site["name"].split('$$$')[0]
        weight = annealing_factor if prefix in latents_to_anneal else 1.0
        return weight * site["fn"].log_prob(site["value"]).sum()

    elbo = 0.0
    for _, site in model_trace.nodes.items():
        if site["type"] != "sample":
            continue
        elbo = elbo + weighted_log_prob(site)
    for _, site in guide_trace.nodes.items():
        if site["type"] != "sample":
            continue
        elbo = elbo - weighted_log_prob(site)
    return -elbo
# Training driver: runs SVI with the annealed loss for a fixed number of
# epochs and reports the per-sample training loss and test accuracy.
# NOTE(review): the original indentation was lost; the nesting below assumes
# the per-epoch statistics and the print belong to the outer (epoch) loop.
pyro.clear_param_store()
optim = Adam({"lr": 0.005})
# Use the custom annealed objective instead of a built-in ELBO.
svi = SVI(model, guide, optim, loss=simple_elbo_kl_annealing)
import math
num_iterations = 100
loss = 0
# annealing_init = 1/60000
# Prefix matched by the loss against the part of each site name before '$$$'.
latents_to_anneal = ['module']
for j in range(num_iterations):
    # Accumulated loss is restarted every epoch.
    loss = 0
    # calculate annealing factor
    # annealing_factor = annealing_init + j*(1-annealing_init)/num_iterations
    # NOTE(review): with num_iterations = 100 this is exp(-99) ~ 1e-43 at
    # j = 0 and only reaches exp(0) = 1.0 at the final epoch (j = 99), so the
    # annealed terms carry essentially no weight for most of training.
    annealing_factor = math.exp(-num_iterations + j + 1)
    for batch_id, data in enumerate(train_loader):
        # calculate the loss and take a gradient step
        # data[0] is reshaped to (-1, 1, 28, 28) and both tensors are moved
        # to the GPU; the two keyword arguments are consumed by the loss.
        loss += svi.step(data[0].view(-1, 1, 28, 28).cuda(), data[1].cuda(), annealing_factor=annealing_factor,
                         latents_to_anneal=latents_to_anneal)
    # Report the epoch loss per training example.
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    # `evaluate` is defined outside this section; presumably the first
    # argument is a number of posterior samples -- TODO confirm.
    acc = evaluate(10, test_loader)
    print("Epoch ", j, " Loss ", total_epoch_loss_train, 'Accuracy ', acc)