Vectorized particles for linear model

Hello,

I am trying to implement a neural network where the last layer is Bayesian, i.e. a deterministic feature extractor followed by a Bayesian linear model. I would like the model to be vectorizable so that I can use several ELBO samples/particles efficiently. I don't think PyTorch's nn.Linear handles batched weights, so I wrote a modified version, BatchLinear, that I think should work with batches of weights/biases. However, something seems to be wrong, as the ELBO loss increases with the number of vectorized particles. For now I am using AutoMultivariateNormal as a guide, but I will later switch to normalizing flows.

I have the following questions:

  1. What am I doing wrong?
  2. When I use several vectorized particles, the shapes of the weight and bias are [num_particles, 1, out_features, in_features] and [num_particles, 1, out_features], but I would have expected [num_particles, out_features, in_features] and [num_particles, out_features], i.e. an extra singleton dimension is added. Why?
  3. Is there a way to trace the model with vectorized particles, similar to trace = poutine.trace(model).get_trace(x, y); trace.compute_log_prob(); print(trace.format_shapes())? (The closest I have come is the sketch after this list.)
  4. I suspect the batch shape of the observation distribution becomes [num_particles, num_particles, N] instead of [num_particles, 1, N], but I am not really sure.
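
For question 3, the closest I have come is the sketch below, which manually wraps the model in an outer particle plate and traces that (it uses the names model_bll, X and y from the example further down). I am not sure whether placing the plate at dim=-2 actually mirrors what Trace_ELBO(vectorize_particles=True) does internally:

import pyro
import pyro.poutine as poutine

def manually_vectorized_model(x, y=None, num_particles=7):
    # Hypothetical outer plate standing in for the particle dimension that
    # vectorize_particles=True would otherwise add.
    with pyro.plate('particles', num_particles, dim=-2):
        return model_bll.model(x, y)

trace = poutine.trace(manually_vectorized_model).get_trace(X, y)
trace.compute_log_prob()
print(trace.format_shapes())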

I would really appreciate some help/guidance, as I have been stuck on this for a couple of days now.

Below is a (maybe not so) minimal working example:

import pyro
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from copy import deepcopy
from pyro.optim import Adam
from pyro.distributions import Normal
from pyro.infer import SVI, Trace_ELBO
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoMultivariateNormal


class RegressionModel(nn.Module):
    """Simple neural network regression model with one hidden layer"""

    def __init__(self, n_features, noise_var=1, device='cuda'):
        super().__init__()
        self.n_features = n_features
        self.noise_var = noise_var
        self.device = device

        self.fc_1 = nn.Linear(in_features=self.n_features, out_features=25)
        self.last_layer = nn.Linear(in_features=25, out_features=1)

        self.noise_scale = torch.sqrt(
            torch.tensor(self.noise_var, dtype=torch.float, device=self.device)
        )
        self.to(self.device)


    def forward(self, x):
        x = self.fc_1(x)
        x = F.relu(x)
        x = self.last_layer(x)

        return x


    def loss(self, model_output, target):
        log_likelihood = Normal(
            loc=model_output, scale=self.noise_scale).log_prob(target).sum()
        return -log_likelihood


    def optimizer(self, lr=1e-3, weight_decay=0):
        return optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)


class BatchLinear(nn.Linear):
    """Linear layer that (hopefully) supports vectorized elbo samples."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    

    def forward(self, input): 
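        # Broadcasting I am hoping for, given the weight/bias shapes I observe
        # with vectorized particles:
        #   weight:                            [num_particles, 1, out_features, in_features]
        #   input:                             [N, in_features]
        #   input @ weight.transpose(-2, -1):  [num_particles, 1, N, out_features]
        #   bias.unsqueeze(-2):                [num_particles, 1, 1, out_features]
        #   result:                            [num_particles, 1, N, out_features]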
        return input @ self.weight.transpose(-2,-1) + self.bias.unsqueeze(-2)    


class BayesianLastLayer:
    """Neural network with Bayesian last layer"""
    def __init__(self, base_model, prior_var=1, noise_var=1):
        self.base_model = base_model
        self.prior_var = prior_var
        self.noise_var = noise_var
        self.device = base_model.device

        self.prior_loc = torch.tensor(0, dtype=torch.float, device=self.device)
        self.prior_scale = torch.sqrt(
            torch.tensor(prior_var, dtype=torch.float, device=self.device)
        )

        # Replace last layer with a Bayesian equivalent that also supports
        # vectorized particles
        bayesian_last_layer = PyroModule[BatchLinear](
            in_features=base_model.last_layer.in_features,
            out_features=base_model.last_layer.out_features
        )
        bayesian_last_layer.weight = PyroSample(
            Normal(self.prior_loc, self.prior_scale).expand(
                base_model.last_layer.weight.shape).to_event(
                    base_model.last_layer.weight.dim())
        )
        bayesian_last_layer.bias = PyroSample(
            Normal(self.prior_loc, self.prior_scale).expand(
                base_model.last_layer.bias.shape).to_event(
                    base_model.last_layer.bias.dim())
        )

        self.base_model.last_layer = bayesian_last_layer
        
        self.noise_scale = torch.sqrt(
            torch.tensor(self.noise_var, dtype=torch.float, device=self.device)
        )


    def model(self, x, y=None): 
        model_output = self.base_model(x).squeeze(-1)
        with pyro.plate('data', size=len(x)):
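            # Data plate over the N observations at the default dim=-1; any
            # particle batch dims should end up to the left of this.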
            obs = pyro.sample(
                'obs',
                Normal(loc=model_output, scale=self.noise_scale),
                obs=y
            )
            return model_output


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    pyro.set_rng_seed(seed)


def train_bayesian_last_layer(X, y, model, num_particles, vectorize_particles):
    pyro.clear_param_store()
    print(f'Num particles: {num_particles}; vectorized: {vectorize_particles}')
    set_seed(0)

    # Make last layer Bayesian
    model = deepcopy(model)
    model_bll = BayesianLastLayer(model)
    guide = AutoMultivariateNormal(model_bll.model)

    elbo = Trace_ELBO(
        num_particles=num_particles, vectorize_particles=vectorize_particles
    )
    svi = SVI(
        model_bll.model,
        guide,
        Adam({"lr": 1e-4}),
        elbo
    )

    n_svi_steps = 1000
    for step in range(n_svi_steps):
        loss = svi.step(X, y)
        if step % 100 == 0:
            print(f'SVI step: {step}, Loss: {loss}')
    print('\n')

# Data
set_seed(0)
device = 'cuda'
N, M = 1000, 15
X = torch.randn(N, M, device=device)
w = torch.randn(M, 1, device=device)
y = X @ w + torch.randn(N, 1, device=device)
y = y.squeeze(-1)
print(X.shape, y.shape)

# Train initial neural network
n_train_steps = 1000
model = RegressionModel(n_features=M, device=device)
optimizer = model.optimizer()
for step in range(n_train_steps):
    model_output = model(X)
    loss = model.loss(model_output, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f'Train step: {step}; Loss: {loss.detach().item()}')
print('\n')

# Seems to work
train_bayesian_last_layer(
    X, y, model, num_particles=1, vectorize_particles=False
)

# Not sure if this works
train_bayesian_last_layer(
    X, y, model, num_particles=1, vectorize_particles=True
)

# Seems to work
train_bayesian_last_layer(
    X, y, model, num_particles=7, vectorize_particles=False
)

# Does not seem to work, loss increases a lot
train_bayesian_last_layer(
    X, y, model, num_particles=7, vectorize_particles=True
)

Instead of doing this from scratch, I suggest you check whether TyXe can support this out of the box for you.
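
Something roughly along these lines; this is from memory, so the exact names and signatures (VariationalBNN, IIDPrior, HomoskedasticGaussian, the guide builder) should be checked against the TyXe README, and the prior/likelihood settings here are just placeholders:

import pyro.distributions as dist
import tyxe

# Reuse the deterministic network defined above.
net = RegressionModel(n_features=M, device=device)

# Placeholder choices: standard normal prior over the weights and a
# homoskedastic Gaussian likelihood over the N training points.
prior = tyxe.priors.IIDPrior(dist.Normal(0.0, 1.0))
likelihood = tyxe.likelihoods.HomoskedasticGaussian(N, scale=1.0)

# Mean-field guide builder; other Pyro autoguides can be plugged in here.
bnn = tyxe.VariationalBNN(net, prior, likelihood, tyxe.guides.AutoNormal)

# Whether TyXe can keep the feature extractor deterministic and make only the
# last layer Bayesian is something to check in its docs.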
