Hi,
I am trying to build Bayesian CNNs with Pyro's model-guide framework. Training is very slow, and after a few epochs the loss only comes down to around 50 (categorical loss). The code is below. Could anyone tell me what I am doing wrong?
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
from pyro.distributions import Normal, Categorical

class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.dropout1 = nn.Dropout(0.5)
        self.out = nn.Linear(128, 10)

    def forward(self, x):
        output = F.relu(self.conv1(x))
        # Max pooling over a (2, 2) window
        output = F.max_pool2d(F.relu(self.conv2(output)), (2, 2))
        output = self.dropout(output)
        output = output.view(-1, self.num_flat_features(output))
        output = F.relu(self.fc1(output))
        output = self.dropout1(output)  # the 0.5 dropout defined for the fc stage
        output = self.out(output)
        return output

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
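The fc1 input size follows from the shapes, assuming 28x28 MNIST input: conv1 gives 26x26, conv2 gives 24x24, and the (2, 2) max-pool gives 12x12, hence 64 * 12 * 12. A quick sanity check:

# Shape check, assuming a 28x28 single-channel input
print(NN()(torch.randn(1, 1, 28, 28)).shape)  # torch.Size([1, 10])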
net = NN()

def model(x_data, y_data):
    # Standard normal priors over every weight and bias of the network
    conv1w_prior = Normal(loc=torch.zeros_like(net.conv1.weight), scale=torch.ones_like(net.conv1.weight))
    conv1b_prior = Normal(loc=torch.zeros_like(net.conv1.bias), scale=torch.ones_like(net.conv1.bias))
    conv2w_prior = Normal(loc=torch.zeros_like(net.conv2.weight), scale=torch.ones_like(net.conv2.weight))
    conv2b_prior = Normal(loc=torch.zeros_like(net.conv2.bias), scale=torch.ones_like(net.conv2.bias))
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight), scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias), scale=torch.ones_like(net.out.bias))
    priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
              'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'out.weight': outw_prior, 'out.bias': outb_prior}
    # Lift the module to a distribution over networks and sample one instance
    lifted_module = pyro.random_module("module", net, priors)
    lifted_reg_model = lifted_module()
    lhat = F.log_softmax(lifted_reg_model(x_data), dim=1)
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)
softplus = torch.nn.Softplus()  # maps unconstrained params to positive scales

def guide(x_data, y_data):
    # First convolution layer weight distribution
    conv1w_mu = torch.randn_like(net.conv1.weight)
    conv1w_sigma = torch.randn_like(net.conv1.weight)
    conv1w_mu_param = pyro.param("conv1w_mu", conv1w_mu)
    conv1w_sigma_param = softplus(pyro.param("conv1w_sigma", conv1w_sigma))
    conv1w_prior = Normal(loc=conv1w_mu_param, scale=conv1w_sigma_param)
    # First convolution layer bias distribution
    conv1b_mu = torch.randn_like(net.conv1.bias)
    conv1b_sigma = torch.randn_like(net.conv1.bias)
    conv1b_mu_param = pyro.param("conv1b_mu", conv1b_mu)
    conv1b_sigma_param = softplus(pyro.param("conv1b_sigma", conv1b_sigma))
    conv1b_prior = Normal(loc=conv1b_mu_param, scale=conv1b_sigma_param)
    # Second convolution layer weight distribution
    conv2w_mu = torch.randn_like(net.conv2.weight)
    conv2w_sigma = torch.randn_like(net.conv2.weight)
    conv2w_mu_param = pyro.param("conv2w_mu", conv2w_mu)
    conv2w_sigma_param = softplus(pyro.param("conv2w_sigma", conv2w_sigma))
    conv2w_prior = Normal(loc=conv2w_mu_param, scale=conv2w_sigma_param)
    # Second convolution layer bias distribution
    conv2b_mu = torch.randn_like(net.conv2.bias)
    conv2b_sigma = torch.randn_like(net.conv2.bias)
    conv2b_mu_param = pyro.param("conv2b_mu", conv2b_mu)
    conv2b_sigma_param = softplus(pyro.param("conv2b_sigma", conv2b_sigma))
    conv2b_prior = Normal(loc=conv2b_mu_param, scale=conv2b_sigma_param)
    # Fully connected layer weight distribution
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # Fully connected layer bias distribution
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
    # Output layer weight distribution
    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    outw_prior = Normal(loc=outw_mu_param, scale=outw_sigma_param).independent(1)
    # Output layer bias distribution
    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)
    priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
              'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'out.weight': outw_prior, 'out.bias': outb_prior}
    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
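The svi object used in the training loop below is set up along these lines (the optimizer choice and learning rate are my assumptions, not shown in the original code):

from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

# Assumed inference setup; lr = 0.01 is illustrative
optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss=Trace_ELBO())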
num_iterations = 5
for j in range(num_iterations):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        # calculate the ELBO loss and take a gradient step
        loss += svi.step(data[0], data[1])
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    print("Epoch ", j, " Loss ", total_epoch_loss_train)
With the conv layers, prediction accuracy is very low: the model predicts at chance level on both the training and the test data. With a single fully connected Bayesian layer on MNIST, the same setup reaches 88% accuracy.