Writing a custom guide for a PyroModule with convolutional layers

Hi everyone,
I’m quite new to the Pyro framework. I’m trying to understand its functionality by building some basic Bayesian neural networks.

I successfully built a basic classifier for the MNIST dataset, writing the model and the custom guide as follows:

class BNN(PyroModule):
    """Bayesian MLP classifier (one hidden layer) for square single-channel images.

    Every weight and bias is a ``PyroSample`` with an i.i.d.
    ``Normal(0, prior_scale)`` prior, so each call to ``forward`` runs a
    network whose parameters are drawn from the prior (or from the guide when
    run under SVI).

    Args:
        input_size: side length of the square input image; ``fc1`` takes
            ``input_size ** 2`` features.
        hidden_size: number of hidden units.
        output_size: number of classes.
        prior_scale: standard deviation of the Normal prior on every parameter.
    """

    def __init__(self, input_size, hidden_size, output_size, prior_scale):
        super().__init__()
        self.input_size = input_size
        self.activation = nn.ReLU()
        self.output = nn.LogSoftmax(dim=1)
        self.fc1 = PyroModule[nn.Linear](input_size**2, hidden_size)
        self.fc2 = PyroModule[nn.Linear](hidden_size, output_size)

        # Replace each deterministic nn.Parameter with a random variable whose
        # prior has exactly the shape of the parameter it replaces.
        for layer in (self.fc1, self.fc2):
            layer.weight = self._normal_prior(prior_scale, layer.weight.shape)
            layer.bias = self._normal_prior(prior_scale, layer.bias.shape)

    @staticmethod
    def _normal_prior(prior_scale, shape):
        """Return a ``PyroSample`` with an i.i.d. ``Normal(0, prior_scale)`` prior over ``shape``."""
        shape = torch.Size(shape)
        return PyroSample(
            dist.Normal(
                torch.tensor(0., device=device),
                torch.tensor(prior_scale, device=device),
            ).expand(shape).to_event(len(shape))  # all dims belong to one tensor
        )

    def forward(self, x, y=None):
        """Run the network and observe (or sample) one class label per image.

        Args:
            x: images; reshaped to ``(-1, input_size ** 2)``.
            y: optional integer class labels, one per image. When ``None`` the
                labels are sampled instead of observed.

        Returns:
            Log-probabilities of shape ``(batch, output_size)``.
        """
        x = x.reshape(-1, self.input_size**2)
        output = self.activation(self.fc1(x))
        output = self.output(self.fc2(output))
        with pyro.plate("data", x.shape[0]):
            # The plate already declares the batch dimension as conditionally
            # independent, so the likelihood is a plain per-example
            # Categorical. Do NOT add ``.to_event(1)`` here: the plate would
            # re-expand it, counting the whole batch's log-likelihood once per
            # example (and sampling a (batch, batch) tensor when y is None).
            pyro.sample("obs", Categorical(logits=output), obs=y)
        return output
# Maps the guide's unconstrained ``*_sigma`` parameters to positive scales
# (PyTorch's default beta/threshold, spelled out explicitly).
softplus = torch.nn.Softplus(beta=1.0, threshold=20.0)

def guide(x_data, y_data=None):
    """Mean-field Normal variational guide for the fully-connected ``BNN``.

    For each weight and bias of the global ``net`` (fc1, fc2), this does three
    things. It registers a learnable location ``<prefix>_mu``. It registers an
    unconstrained scale ``<prefix>_sigma``, made positive with ``softplus``.
    It then samples the matching model site from
    ``Normal(mu, softplus(sigma))``, with all of the tensor's dimensions
    treated as event dimensions. Site names must equal the model's
    ``"<layer>.<param>"`` names so that SVI pairs them up.

    Args:
        x_data: input images; unused, present only to mirror the model.
        y_data: labels; unused. Defaults to ``None`` so the guide also
            accepts the model's prediction-time call ``(x)``.
    """

    def sample_normal(site_name, param_prefix, layer, attr, event_dim):
        """Register ``<param_prefix>_mu/_sigma`` and sample ``site_name`` from them."""
        # NOTE: reading a PyroSample attribute (e.g. ``net.fc1.weight``)
        # outside the model's forward returns a fresh prior draw. Here it is
        # only a shape/device template. Passing a callable to ``pyro.param``
        # defers that draw (and the randn) to the first call, when the
        # parameter is created, instead of repeating it every SVI step.
        # NOTE(review): randn-initialised pre-softplus scales give initial
        # stds around softplus(0) = ln 2 ~ 0.7. A small constant init
        # (e.g. -3) often trains more stably, so consider tuning.
        def init():
            return torch.randn_like(getattr(layer, attr), device=device)

        loc = pyro.param(f"{param_prefix}_mu", init)
        scale = softplus(pyro.param(f"{param_prefix}_sigma", init))
        return pyro.sample(site_name, Normal(loc=loc, scale=scale).to_event(event_dim))

    sample_normal("fc1.weight", "fc1w", net.fc1, "weight", 2)
    sample_normal("fc1.bias", "fc1b", net.fc1, "bias", 1)
    sample_normal("fc2.weight", "outw", net.fc2, "weight", 2)
    sample_normal("fc2.bias", "outb", net.fc2, "bias", 1)

I was able to train the model using SVI and run inference on some test images.

I would like to modify the model, and thus the guide, by adding a couple of convolutional layers. I understand that, due to the larger number of parameters, this might not be feasible for very large fully convolutional models, but I would like to understand how to implement them.

Can anyone show me how to define the weights and biases for the Conv2d layers?

I tried it this way:

class BNN(PyroModule):
    """Bayesian CNN classifier: two conv/ReLU/max-pool stages followed by a 2-layer MLP.

    Every conv and linear weight/bias is a ``PyroSample`` with an i.i.d.
    ``Normal(0, prior_scale)`` prior, so each call to ``forward`` runs a
    network whose parameters are drawn from the prior (or from the guide when
    run under SVI).

    Args:
        input_size: side length of the square, single-channel input image.
        hidden_size: number of hidden units in ``fc1``.
        output_size: number of classes.
        prior_scale: standard deviation of the Normal prior on every parameter.
    """

    def __init__(self, input_size, hidden_size, output_size, prior_scale):
        super().__init__()
        self.input_size = input_size
        self.activation = nn.ReLU()
        self.output = nn.LogSoftmax(dim=1)
        self.conv1 = PyroModule[nn.Conv2d](1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # The 3x3 convs with padding=1 keep the spatial size. Each of the two
        # 2x2/stride-2 max-pools halves it (rounding down).
        conv_output_size = input_size // 2 // 2
        linear_input_size = 64 * conv_output_size * conv_output_size

        self.fc1 = PyroModule[nn.Linear](linear_input_size, hidden_size)
        self.fc2 = PyroModule[nn.Linear](hidden_size, output_size)

        # Turn every weight and bias into a random variable whose prior has
        # the exact shape of the nn.Parameter it replaces. Conv weights are
        # (out_ch, in_ch, kH, kW), so they get 4 event dims; linear weights
        # get 2; biases get 1.
        for layer in (self.conv1, self.conv2, self.fc1, self.fc2):
            layer.weight = self._normal_prior(prior_scale, layer.weight.shape)
            layer.bias = self._normal_prior(prior_scale, layer.bias.shape)

    @staticmethod
    def _normal_prior(prior_scale, shape):
        """Return a ``PyroSample`` with an i.i.d. ``Normal(0, prior_scale)`` prior over ``shape``."""
        shape = torch.Size(shape)
        return PyroSample(
            dist.Normal(
                torch.tensor(0., device=device),
                torch.tensor(prior_scale, device=device),
            ).expand(shape).to_event(len(shape))  # all dims belong to one tensor
        )

    def forward(self, x, y=None):
        """Run the CNN and observe (or sample) one class label per image.

        Args:
            x: images; reshaped to ``(-1, 1, input_size, input_size)``.
            y: optional integer class labels, one per image. When ``None`` the
                labels are sampled instead of observed.

        Returns:
            Log-probabilities of shape ``(batch, output_size)``.
        """
        x = x.reshape(-1, 1, self.input_size, self.input_size)
        output = self.pool(self.activation(self.conv1(x)))
        output = self.pool(self.activation(self.conv2(output)))
        output = output.flatten(start_dim=1)  # (batch, channels * H * W) for fc1
        output = self.activation(self.fc1(output))
        output = self.output(self.fc2(output))
        with pyro.plate("data", x.shape[0]):
            # The plate already declares the batch dimension as conditionally
            # independent, so the likelihood is a plain per-example
            # Categorical. Do NOT add ``.to_event(1)`` here: the plate would
            # re-expand it, counting the whole batch's log-likelihood once per
            # example (and sampling a (batch, batch) tensor when y is None).
            pyro.sample("obs", Categorical(logits=output), obs=y)
        return output
# Maps the guide's unconstrained ``*_sigma`` parameters to positive scales
# (PyTorch's default beta/threshold, spelled out explicitly).
softplus = torch.nn.Softplus(beta=1.0, threshold=20.0)

def guide(x_data, y_data=None):
    """Mean-field Normal variational guide for the convolutional ``BNN``.

    For each weight and bias of the global ``net`` (conv1, conv2, fc1, fc2),
    this does three things. It registers a learnable location
    ``<prefix>_mu``. It registers an unconstrained scale ``<prefix>_sigma``,
    made positive with ``softplus``. It then samples the matching model site
    from ``Normal(mu, softplus(sigma))``, with all of the tensor's dimensions
    treated as event dimensions. Site names must equal the model's
    ``"<layer>.<param>"`` names so that SVI pairs them up.

    Args:
        x_data: input images; unused, present only to mirror the model.
        y_data: labels; unused. Defaults to ``None`` so the guide also
            accepts the model's prediction-time call ``(x)``.
    """

    def sample_normal(site_name, param_prefix, layer, attr, event_dim):
        """Register ``<param_prefix>_mu/_sigma`` and sample ``site_name`` from them."""
        # NOTE: reading a PyroSample attribute (e.g. ``net.conv1.weight``)
        # outside the model's forward returns a fresh prior draw. Here it is
        # only a shape/device template. Passing a callable to ``pyro.param``
        # defers that draw (and the randn) to the first call, when the
        # parameter is created, instead of repeating it every SVI step.
        # NOTE(review): randn-initialised pre-softplus scales give initial
        # stds around softplus(0) = ln 2 ~ 0.7. A small constant init
        # (e.g. -3) often trains more stably, so consider tuning.
        def init():
            return torch.randn_like(getattr(layer, attr), device=device)

        loc = pyro.param(f"{param_prefix}_mu", init)
        scale = softplus(pyro.param(f"{param_prefix}_sigma", init))
        return pyro.sample(site_name, Normal(loc=loc, scale=scale).to_event(event_dim))

    # Conv weights are (out_ch, in_ch, kH, kW) -> 4 event dims.
    sample_normal("conv1.weight", "conv1w", net.conv1, "weight", 4)
    sample_normal("conv1.bias", "conv1b", net.conv1, "bias", 1)
    sample_normal("conv2.weight", "conv2w", net.conv2, "weight", 4)
    sample_normal("conv2.bias", "conv2b", net.conv2, "bias", 1)
    # Linear weights are (out_features, in_features) -> 2 event dims.
    sample_normal("fc1.weight", "fc1w", net.fc1, "weight", 2)
    sample_normal("fc1.bias", "fc1b", net.fc1, "bias", 1)
    sample_normal("fc2.weight", "outw", net.fc2, "weight", 2)
    sample_normal("fc2.bias", "outb", net.fc2, "bias", 1)

What do you think about that?

Thanks for your help