Hi everyone,
I’m quite new to the Pyro framework. I’m trying to understand its functionality by building some basic Bayesian neural networks.
I successfully built a basic classifier for the MNIST dataset, writing the model and the custom guide as follows:
```python
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.distributions import Categorical, Normal
from pyro.nn import PyroModule, PyroSample

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BNN(PyroModule):
    def __init__(self, input_size, hidden_size, output_size, prior_scale):
        super().__init__()
        self.activation = nn.ReLU()
        self.output = nn.LogSoftmax(dim=1)
        self.fc1 = PyroModule[nn.Linear](input_size**2, hidden_size)
        self.fc2 = PyroModule[nn.Linear](hidden_size, output_size)
        # Set layer parameters as random variables
        self.fc1.weight = PyroSample(dist.Normal(
            torch.tensor(0., device=device),
            torch.tensor(prior_scale, device=device)
        ).expand([hidden_size, input_size**2]).to_event(2))
        self.fc1.bias = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(prior_scale, device=device)
            ).expand([hidden_size]).to_event(1))
        self.fc2.weight = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(prior_scale, device=device)
            ).expand([output_size, hidden_size]).to_event(2))
        self.fc2.bias = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(prior_scale, device=device)
            ).expand([output_size]).to_event(1))

    def forward(self, x, y=None):
        x = x.reshape(-1, 28 * 28)  # assumes 28x28 MNIST images
        output = self.activation(self.fc1(x))
        output = self.fc2(output)
        output = self.output(output)
        # Each label is conditionally independent given the weights; the
        # Categorical already scores one label per plate element, so no
        # .to_event(1) here (it would clash with the plate dimension).
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", Categorical(logits=output), obs=y)
        return output
```
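The guide below closes over a global `net` (and the `device` defined above); I instantiate it along these lines — the hidden size and prior scale here are just placeholder values, not tuned choices:

```python
# Hidden size and prior scale are example values, not tuned choices.
net = BNN(input_size=28, hidden_size=1024, output_size=10, prior_scale=1.0).to(device)
```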
```python
softplus = torch.nn.Softplus()

def guide(x_data, y_data=None):
    # The guide must accept the same arguments as the model. Note that
    # reading net.fc1.weight etc. triggers pyro.sample calls (those
    # attributes are PyroSample), so the reads are wrapped in
    # poutine.block() to keep the prior draws used for shapes out of the
    # guide trace, where they would collide with the sites declared below.
    # First layer weight posterior
    with pyro.poutine.block():
        fc1w_mu = torch.randn_like(net.fc1.weight)
        fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w = pyro.sample("fc1.weight", Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param).to_event(2))
    # First layer bias posterior
    with pyro.poutine.block():
        fc1b_mu = torch.randn_like(net.fc1.bias)
        fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b = pyro.sample("fc1.bias", Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param).to_event(1))
    # Output layer weight posterior
    with pyro.poutine.block():
        outw_mu = torch.randn_like(net.fc2.weight)
        outw_sigma = torch.randn_like(net.fc2.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    fc2w = pyro.sample("fc2.weight", Normal(loc=outw_mu_param, scale=outw_sigma_param).to_event(2))
    # Output layer bias posterior
    with pyro.poutine.block():
        outb_mu = torch.randn_like(net.fc2.bias)
        outb_sigma = torch.randn_like(net.fc2.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    fc2b = pyro.sample("fc2.bias", Normal(loc=outb_mu_param, scale=outb_sigma_param).to_event(1))
```
I was able to train the model with SVI and run inference on some test images.
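In case it helps, my setup was essentially the standard SVI loop, something like the following — here `train_loader` and `test_images` stand for an ordinary MNIST `DataLoader` and a batch of test images, and the learning rate, epoch count, and sample count are just the values I happened to use:

```python
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
svi = SVI(net, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())

for epoch in range(5):
    total_loss = 0.0
    for images, labels in train_loader:
        total_loss += svi.step(images.to(device), labels.to(device))
    print(f"epoch {epoch}: ELBO loss {total_loss / len(train_loader.dataset):.4f}")

# Posterior predictive: the model returns class log-probabilities, recorded
# by Predictive under "_RETURN"; average the probabilities over weight samples.
predictive = Predictive(net, guide=guide, num_samples=50, return_sites=("_RETURN",))
log_probs = predictive(test_images.to(device))["_RETURN"]  # [50, batch, 10]
pred_labels = log_probs.exp().mean(0).argmax(-1)
```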
I would now like to modify the model, and therefore the guide, by adding a couple of convolutional layers. I understand that, due to the larger number of parameters, this might not be feasible for very large fully convolutional models, but I would like to understand how to implement them.
Can anyone show me how to define the weights and biases for the Conv2d layers?
I tried it this way:
```python
class BNN(PyroModule):
    def __init__(self, input_size, hidden_size, output_size, prior_scale):
        super().__init__()
        self.activation = nn.ReLU()
        self.output = nn.LogSoftmax(dim=1)
        self.conv1 = PyroModule[nn.Conv2d](1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Input size of the linear layers after convolution and pooling:
        # padding=1 preserves the spatial resolution and each of the two 2x2
        # max-pools halves it (28 -> 14 -> 7), so for input_size=28 this
        # gives 64 * 7 * 7 = 3136 flattened features.
        conv_output_size = input_size // 2 // 2
        linear_input_size = 64 * conv_output_size * conv_output_size
        self.fc1 = PyroModule[nn.Linear](linear_input_size, hidden_size)
        self.fc2 = PyroModule[nn.Linear](hidden_size, output_size)
        # Set convolutional layer parameters as random variables; conv weights
        # have four dims (out_channels, in_channels, kH, kW), hence to_event(4)
        self.conv1.weight = PyroSample(dist.Normal(
            torch.tensor(0., device=device),
            torch.tensor(prior_scale, device=device)
        ).expand([32, 1, 3, 3]).to_event(4))
        self.conv1.bias = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(prior_scale, device=device)
            ).expand([32]).to_event(1))
        self.conv2.weight = PyroSample(dist.Normal(
            torch.tensor(0., device=device),
            torch.tensor(prior_scale, device=device)
        ).expand([64, 32, 3, 3]).to_event(4))
        self.conv2.bias = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(prior_scale, device=device)
            ).expand([64]).to_event(1))
        # Set linear layer parameters as random variables
        self.fc1.weight = PyroSample(dist.Normal(
            torch.tensor(0., device=device),
            torch.tensor(prior_scale, device=device)
        ).expand([hidden_size, linear_input_size]).to_event(2))
        self.fc1.bias = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(prior_scale, device=device)
            ).expand([hidden_size]).to_event(1))
        self.fc2.weight = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(prior_scale, device=device)
            ).expand([output_size, hidden_size]).to_event(2))
        self.fc2.bias = PyroSample(
            dist.Normal(torch.tensor(0., device=device),
                        torch.tensor(prior_scale, device=device)
            ).expand([output_size]).to_event(1))

    def forward(self, x, y=None):
        x = x.view(-1, 1, 28, 28)
        output = self.pool(self.activation(self.conv1(x)))
        output = self.pool(self.activation(self.conv2(output)))
        output = output.view(output.size(0), -1)  # flatten for the linear layers
        output = self.activation(self.fc1(output))
        output = self.fc2(output)
        output = self.output(output)
        # Same fix as above: no .to_event(1) on the Categorical inside the plate
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", Categorical(logits=output), obs=y)
        return output
```
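A dummy forward pass is a quick way to confirm the size arithmetic above; the weights are drawn from their priors, so this only checks shapes, not predictions (it assumes `net` has been re-created from this class):

```python
# Expect one row of 10 class log-probabilities per image in the batch.
dummy = torch.randn(4, 1, 28, 28, device=device)
print(net(dummy).shape)  # expected: torch.Size([4, 10])
```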
```python
softplus = torch.nn.Softplus()

def guide(x_data, y_data=None):
    # As in the fully connected version: the PyroSample attributes are read
    # inside poutine.block() so that the prior draws used for shapes don't
    # end up in the guide trace.
    # First convolutional layer weight posterior
    with pyro.poutine.block():
        conv1w_mu = torch.randn_like(net.conv1.weight)
        conv1w_sigma = torch.randn_like(net.conv1.weight)
    conv1w_mu_param = pyro.param("conv1w_mu", conv1w_mu)
    conv1w_sigma_param = softplus(pyro.param("conv1w_sigma", conv1w_sigma))
    conv1w = pyro.sample("conv1.weight", Normal(loc=conv1w_mu_param, scale=conv1w_sigma_param).to_event(4))
    # First convolutional layer bias posterior
    with pyro.poutine.block():
        conv1b_mu = torch.randn_like(net.conv1.bias)
        conv1b_sigma = torch.randn_like(net.conv1.bias)
    conv1b_mu_param = pyro.param("conv1b_mu", conv1b_mu)
    conv1b_sigma_param = softplus(pyro.param("conv1b_sigma", conv1b_sigma))
    conv1b = pyro.sample("conv1.bias", Normal(loc=conv1b_mu_param, scale=conv1b_sigma_param).to_event(1))
    # Second convolutional layer weight posterior
    with pyro.poutine.block():
        conv2w_mu = torch.randn_like(net.conv2.weight)
        conv2w_sigma = torch.randn_like(net.conv2.weight)
    conv2w_mu_param = pyro.param("conv2w_mu", conv2w_mu)
    conv2w_sigma_param = softplus(pyro.param("conv2w_sigma", conv2w_sigma))
    conv2w = pyro.sample("conv2.weight", Normal(loc=conv2w_mu_param, scale=conv2w_sigma_param).to_event(4))
    # Second convolutional layer bias posterior
    with pyro.poutine.block():
        conv2b_mu = torch.randn_like(net.conv2.bias)
        conv2b_sigma = torch.randn_like(net.conv2.bias)
    conv2b_mu_param = pyro.param("conv2b_mu", conv2b_mu)
    conv2b_sigma_param = softplus(pyro.param("conv2b_sigma", conv2b_sigma))
    conv2b = pyro.sample("conv2.bias", Normal(loc=conv2b_mu_param, scale=conv2b_sigma_param).to_event(1))
    # First linear layer weight posterior
    with pyro.poutine.block():
        fc1w_mu = torch.randn_like(net.fc1.weight)
        fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w = pyro.sample("fc1.weight", Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param).to_event(2))
    # First linear layer bias posterior
    with pyro.poutine.block():
        fc1b_mu = torch.randn_like(net.fc1.bias)
        fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b = pyro.sample("fc1.bias", Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param).to_event(1))
    # Output layer weight posterior
    with pyro.poutine.block():
        outw_mu = torch.randn_like(net.fc2.weight)
        outw_sigma = torch.randn_like(net.fc2.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    fc2w = pyro.sample("fc2.weight", Normal(loc=outw_mu_param, scale=outw_sigma_param).to_event(2))
    # Output layer bias posterior
    with pyro.poutine.block():
        outb_mu = torch.randn_like(net.fc2.bias)
        outb_sigma = torch.randn_like(net.fc2.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    fc2b = pyro.sample("fc2.bias", Normal(loc=outb_mu_param, scale=outb_sigma_param).to_event(1))
```
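If it helps for comparison: as far as I understand, this hand-written guide parameterizes the same mean-field normal family that Pyro's `AutoNormal` autoguide builds automatically, so an autoguide could serve as a cross-check:

```python
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

# AutoNormal creates one Normal factor (learnable loc and scale) per latent
# site in the model -- the same mean-field family as the manual guide above.
auto_guide = AutoNormal(net)
svi = SVI(net, auto_guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
```

Still, I'd like to confirm that the manual version above is correct.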
What do you think about that?
Thanks for your help