Predictive model memory skyrockets

I am trying to build a Bayesian neural net with Pyro, and feel as though I'm getting close. I can train and predict on a simple test dataset; however, after training on my actual data (approx. 1000 features) I cannot run the data through the Predictive model due to memory limitations. Even if I take the model off the GPU and run on the CPU, the process eats up more than 125 GiB of RAM and crashes.

I suspect I am doing something very wrong in the model/guide architecture that is causing this. Here is what I have; I would appreciate anyone pointing out where I've gone wrong.

import torch
import pyro
import pyro.distributions
import pyro.infer
import pyro.nn

#
# A basic Linear layer for Pyro, set up to learn the loc and scale of Normally distributed weights and biases
#
class PyroLinear(torch.nn.Linear, pyro.nn.PyroModule): 
    def __init__(self, in_features, out_features, device='cpu', **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        mu = torch.randn_like( self.weight, device=device)
        sigma = torch.rand_like( self.weight, device=device) + 0.01
        self.weight = pyro.nn.PyroSample( pyro.distributions.Normal( mu, sigma).expand([self.out_features, self.in_features]).to_event(2))
        if self.bias is not None:
            mu = torch.randn_like( self.bias, device=device)
            sigma = torch.rand_like( self.bias, device=device) + 0.01
            self.bias = pyro.nn.PyroSample( pyro.distributions.Normal( mu, sigma).expand([self.out_features]).to_event(1))
            
    @property
    def device(self):
        if len( list(self.parameters())) > 0:
            return next(self.parameters()).device
        else:
            return 'cpu'

#
# Helper function to create a PyroLinear layer followed by an activation function
#
def layer_block( n_in, n_out, device='cpu'):
    return pyro.nn.PyroModule[ torch.nn.Sequential](
        PyroLinear(n_in, n_out, device=device),
        torch.nn.Tanh()
    )

# 
# My full model architecture with model/guide
# 
class NNet_Model(pyro.nn.PyroModule):
    def __init__(self, n_inputs=1,  h_layers=[20], scale=0.1, device='cpu'):
        super().__init__()
        self.predictive = None
        self.norm_scale = scale
        self.layer_sizes = [n_inputs, *h_layers]

        layer_blocks = [ layer_block(in_f, out_f, device=device) for in_f, out_f in zip(self.layer_sizes[:-1], self.layer_sizes[1:])]
        
        self.feature_net = pyro.nn.PyroModule[ torch.nn.Sequential](*layer_blocks)
        self.out = PyroLinear( self.layer_sizes[ -1], 2, device=device)
        
    #
    # The latent model function used to train with SVI
    #
    def model_base( self, x, y=None):
        x = self.feature_net( x)
        mu = self.out( x)
        prob = torch.softmax( mu, dim=1)
        pred = (prob[:,1] - prob[:,0]) * 0.5 + 0.5
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample( "obs", pyro.distributions.Normal( loc=pred, scale=torch.tensor(self.norm_scale).to( self.device)), obs=y).type( torch.float32)

    #
    # The guide function used in training with SVI
    #
    def guide(self, x, y=None):
        for i in range(0, len( self.feature_net)):
            mu = torch.randn_like( self.feature_net[i][0].weight)
            sigma = torch.rand_like( self.feature_net[i][0].weight) + 0.1
            mu_param = pyro.param( f"feature_net.{i}.0.w_mu", mu)
            sigma_param = pyro.param( f"feature_net.{i}.0.w_sigma", sigma, constraint=pyro.distributions.constraints.positive)
            _ = pyro.sample( f"feature_net.{i}.0.weight",  pyro.distributions.Normal(mu_param, sigma_param).expand([self.feature_net[i][0].weight.size(0), self.feature_net[i][0].weight.size(1)]).to_event(2))
            mu = torch.randn_like( self.feature_net[i][0].bias)
            sigma = torch.rand_like( self.feature_net[i][0].bias) + 0.1
            mu_param = pyro.param( f"feature_net.{i}.0.b_mu", mu)
            sigma_param = pyro.param( f"feature_net.{i}.0.b_sigma", sigma, constraint=pyro.distributions.constraints.positive)
            _ = pyro.sample( f"feature_net.{i}.0.bias",  pyro.distributions.Normal(mu_param, sigma_param).expand([self.feature_net[i][0].bias.size(0)]).to_event(1))
        mu = torch.randn_like( self.out.weight)
        sigma = torch.rand_like( self.out.weight) + 0.1
        mu_param = pyro.param( f"out.w_mu", mu)
        sigma_param = pyro.param( f"out.w_sigma", sigma, constraint=pyro.distributions.constraints.positive)
        _ = pyro.sample( f"out.weight",  pyro.distributions.Normal(mu_param, sigma_param).expand([self.out.weight.size(0), self.out.weight.size(1)]).to_event(2))
        mu = torch.randn_like( self.out.bias)
        sigma = torch.rand_like( self.out.bias) + 0.1
        mu_param = pyro.param( f"out.b_mu", mu)
        sigma_param = pyro.param( f"out.b_sigma", sigma, constraint=pyro.distributions.constraints.positive)
        _ = pyro.sample( f"out.bias",  pyro.distributions.Normal(mu_param, sigma_param).expand([self.out.bias.size(0)]).to_event(1))
    
    # Return the predictions along with the std
    #    - This creates the Predictive model if it does not yet exist.
    #    - It then assembles the final mean prediction from "obs" and returns it along with the std
    def forward(self, x, samples=100):
        if self.predictive is None:
            self.predictive = pyro.infer.Predictive(self.model_base, guide=self.guide, num_samples=samples)

        result = self.predictive( x)
        y_pred = result[ 'obs'].mean( dim=0)
        y_std = result[ 'obs'].std( dim=0)
        return y_pred.view( x.shape[0], -1), y_std.view( x.shape[0], -1)

    @property
    def device(self):
        return self.out.weight.device
    
    def to( self, device):
        super().to( device)
        for i in range(0, len( self.feature_net)):
            self.feature_net[i][0].weight = self.feature_net[i][0].weight.to( device)
            self.feature_net[i][0].bias = self.feature_net[i][0].bias.to( device)
        self.out.weight = self.out.weight.to( device)
        self.out.bias = self.out.bias.to( device)
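
For reference, training uses a fairly standard SVI loop. The following is only a paraphrased sketch with placeholder data and hyper-parameters (my real sizes and values differ):

import pyro.optim

# Placeholder data standing in for my real dataset (~1000 features).
net = NNet_Model(n_inputs=1000, h_layers=[5000], scale=0.1, device='cpu')
x_train = torch.randn(256, 1000)
y_train = torch.rand(256)

optimizer = pyro.optim.Adam({"lr": 1e-3})
svi = pyro.infer.SVI(net.model_base, net.guide, optimizer, loss=pyro.infer.Trace_ELBO())

pyro.clear_param_store()
for step in range(2000):
    loss = svi.step(x_train, y_train)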

I am also rather confused as to why this seems to work (on a small dataset) even though my model_base() and guide() do not actually return anything. I guess this is where Pyro does all its work under the hood, which makes it difficult to follow what is actually going on.
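
From what I can tell, each pyro.sample call inside model_base() and guide() registers a named site on an execution trace, and SVI/Predictive work off those recorded sites rather than a return value. Tracing one run of the model (reusing net and x_train from the training sketch above) at least shows what gets recorded:

import pyro.poutine as poutine

# Record a single execution of the model and print the sample sites Pyro sees.
trace = poutine.trace(net.model_base).get_trace(x_train)
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        print(name, tuple(site["value"].shape))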

I have since switched to using an AutoGuide with my BNN model; however, with a large network (around 5000 nodes in the hidden layer), drawing a Predictive sample of 1000 quickly eats up over 80 GiB of RAM. This seems excessive. I only need the final prediction, not the full set of 1000 samples per data point, so how do I assign the learned parameters to the model so that I can run a prediction as many times as I want, accumulating the mean on the fly?
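
What I have in mind is something along these lines: draw the posterior predictive samples in small batches and keep only running sums, so the full set of samples never has to sit in memory at once. This is just a rough, untested sketch of the idea (predict_streaming and the batch size are my own placeholders, not an existing Pyro API):

# Rough sketch: accumulate the predictive mean/std over many small Predictive
# calls instead of a single call with num_samples=1000.
def predict_streaming(model, guide, x, num_samples=1000, batch_size=10):
    predictive = pyro.infer.Predictive(model, guide=guide, num_samples=batch_size)
    total = torch.zeros(x.shape[0], device=x.device)
    total_sq = torch.zeros(x.shape[0], device=x.device)
    drawn = 0
    with torch.no_grad():
        while drawn < num_samples:
            obs = predictive(x)["obs"].detach().reshape(batch_size, -1)  # [batch_size, N]
            total += obs.sum(dim=0)
            total_sq += obs.pow(2).sum(dim=0)
            drawn += batch_size
    mean = total / drawn
    # Population std; close enough to .std() for large sample counts.
    std = (total_sq / drawn - mean.pow(2)).clamp(min=0).sqrt()
    return mean, std

# e.g. with the hand-written guide above (or an AutoGuide in its place):
y_pred, y_std = predict_streaming(net.model_base, net.guide, x_train)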
