Predictive model memory skyrockets

I am trying to build a Bayesian neural net with Pyro, and feel as though I’m getting close. I can train and predict on a simple test dataset; however, after training on my actual data (approx. 1000 features) I cannot run through the Predictive model due to memory limitations. Even if I take the model off the GPU and run on the CPU, the process eats up more than 125 GiB of RAM and crashes.

I suspect I am doing something very wrong in the model/guide architecture that might be causing this. Here is what I have; I would appreciate anyone pointing out where I’ve gone wrong…

import torch
import pyro
import pyro.infer
import pyro.nn

#
# A basic linear layer for Pyro, set up to learn the loc and scale of
# normally distributed weights and biases.
#
class PyroLinear(torch.nn.Linear, pyro.nn.PyroModule):
    def __init__(self, in_features, out_features, device='cpu', **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        # Priors: loc/scale tensors shaped like the underlying weight/bias.
        mu = torch.randn_like(self.weight, device=device)
        sigma = torch.rand_like(self.weight, device=device) + 0.01
        self.weight = pyro.nn.PyroSample(
            pyro.distributions.Normal(mu, sigma).to_event(2))
        if self.bias is not None:
            mu = torch.randn_like(self.bias, device=device)
            sigma = torch.rand_like(self.bias, device=device) + 0.01
            self.bias = pyro.nn.PyroSample(
                pyro.distributions.Normal(mu, sigma).to_event(1))

    @property
    def device(self):
        if len(list(self.parameters())) > 0:
            return next(self.parameters()).device
        return 'cpu'

#
# Helper function to create a PyroLinear layer with an activation function
#
def layer_block(n_in, n_out, device='cpu'):
    return pyro.nn.PyroModule[torch.nn.Sequential](
        PyroLinear(n_in, n_out, device=device),
        torch.nn.Tanh()
    )

#
# My full model architecture with model/guide
#
class NNet_Model(pyro.nn.PyroModule):
    def __init__(self, n_inputs=1, h_layers=[20], scale=0.1, device='cpu'):
        super().__init__()
        self.predictive = None
        self.norm_scale = scale
        self.layer_sizes = [n_inputs, *h_layers]

        layer_blocks = [layer_block(in_f, out_f, device=device)
                        for in_f, out_f in zip(self.layer_sizes[:-1], self.layer_sizes[1:])]

        self.feature_net = pyro.nn.PyroModule[torch.nn.Sequential](*layer_blocks)
        self.out = PyroLinear(self.layer_sizes[-1], 2, device=device)
    #
    # The latent model function used to train with SVI
    #
    def model_base(self, x, y=None):
        x = self.feature_net(x)
        mu = self.out(x)
        prob = torch.softmax(mu, dim=1)
        pred = (prob[:, 1] - prob[:, 0]) * 0.5 + 0.5
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs",
                        pyro.distributions.Normal(loc=pred, scale=torch.tensor(self.norm_scale).to(self.device)),
                        obs=y)

    #
    # The guide function used in training with SVI
    #
    def guide(self, x, y=None):
        for i in range(len(self.feature_net)):
            mu = torch.randn_like(self.feature_net[i][0].weight)
            sigma = torch.rand_like(self.feature_net[i][0].weight) + 0.1
            mu_param = pyro.param(f"feature_net.{i}.0.w_mu", mu)
            sigma_param = pyro.param(f"feature_net.{i}.0.w_sigma", sigma,
                                     constraint=pyro.distributions.constraints.positive)
            pyro.sample(f"feature_net.{i}.0.weight",
                        pyro.distributions.Normal(mu_param, sigma_param).to_event(2))
            mu = torch.randn_like(self.feature_net[i][0].bias)
            sigma = torch.rand_like(self.feature_net[i][0].bias) + 0.1
            mu_param = pyro.param(f"feature_net.{i}.0.b_mu", mu)
            sigma_param = pyro.param(f"feature_net.{i}.0.b_sigma", sigma,
                                     constraint=pyro.distributions.constraints.positive)
            pyro.sample(f"feature_net.{i}.0.bias",
                        pyro.distributions.Normal(mu_param, sigma_param).to_event(1))
        mu = torch.randn_like(self.out.weight)
        sigma = torch.rand_like(self.out.weight) + 0.1
        mu_param = pyro.param("out.w_mu", mu)
        sigma_param = pyro.param("out.w_sigma", sigma,
                                 constraint=pyro.distributions.constraints.positive)
        pyro.sample("out.weight",
                    pyro.distributions.Normal(mu_param, sigma_param).to_event(2))
        mu = torch.randn_like(self.out.bias)
        sigma = torch.rand_like(self.out.bias) + 0.1
        mu_param = pyro.param("out.b_mu", mu)
        sigma_param = pyro.param("out.b_sigma", sigma,
                                 constraint=pyro.distributions.constraints.positive)
        pyro.sample("out.bias",
                    pyro.distributions.Normal(mu_param, sigma_param).to_event(1))
    
    # Return the predictions along with the std.
    #    - Creates the Predictive model if it does not yet exist.
    #    - Assembles the final mean prediction from "obs" and returns it along with the std.
    def forward(self, x, samples=100):
        if self.predictive is None:
            self.predictive = pyro.infer.Predictive(self.model_base, guide=self.guide, num_samples=samples)

        result = self.predictive(x)
        y_pred = result['obs'].mean(dim=0)
        y_std = result['obs'].std(dim=0)
        return y_pred.view(x.shape[0], -1), y_std.view(x.shape[0], -1)

    @property
    def device(self):
        return self.out.weight.device

    def to(self, device):
        super().to(device)
        for i in range(len(self.feature_net)):
            self.feature_net[i][0].weight = self.feature_net[i][0].weight.to(device)
            self.feature_net[i][0].bias = self.feature_net[i][0].bias.to(device)
        self.out.weight = self.out.weight.to(device)
        self.out.bias = self.out.bias.to(device)
        return self  # match torch.nn.Module.to(), which returns self

I am also rather confused as to why this seems to work (on a small dataset) even though my model_base() and guide() do not actually return anything. I guess this is where Pyro does all its work under the hood, which makes it difficult to follow what is actually going on.
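If I understand correctly, the pyro.sample() and pyro.param() calls are recorded as side effects: Pyro’s poutine effect handlers trace the execution of model_base() and guide(), and SVI compares those traces to compute the ELBO. As a sanity check, I can inspect what gets recorded with something like the following sketch (assuming net is an NNet_Model instance and x is a batch of inputs):

import pyro.poutine as poutine

# Running the model under a trace handler records every pyro.sample site;
# the object Pyro cares about is this trace, not the function's return value.
trace = poutine.trace(net.model_base).get_trace(x)
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        print(name, tuple(site["value"].shape))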

I have since switched to using an AutoGuide with my BNN model; however, with a large network (5000 or so nodes in the hidden layer), drawing a Predictive sample of 1000 quickly eats up over 80 GiB of RAM. This seems pretty crazy. I only need the final prediction, not the full set of 1000 draws per data point, so how do I go about assigning the parameters correctly to the model so that I can simply run a prediction as many times as I want, accumulating the mean on the fly?
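The closest I have come to a streaming approach is a sketch like the one below (not tested at scale): draw one posterior sample at a time by replaying guide samples through the model, and keep a running mean/std via Welford’s algorithm, so only a single sample is ever resident in memory. Here model and guide stand for model_base and the trained guide:

import torch
import pyro.poutine as poutine

def streaming_predict(model, guide, x, num_samples=1000):
    mean = m2 = None
    for n in range(1, num_samples + 1):
        # sample latents from the guide, then run the model conditioned on them
        guide_tr = poutine.trace(guide).get_trace(x)
        model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(x)
        obs = model_tr.nodes["obs"]["value"].detach()
        # Welford's online mean/variance update
        if mean is None:
            mean, m2 = obs.clone(), torch.zeros_like(obs)
        else:
            delta = obs - mean
            mean += delta / n
            m2 += delta * (obs - mean)
    std = (m2 / (num_samples - 1)).sqrt()
    return mean, std

Is something like this a reasonable way to do it, or is there a built-in mechanism I am missing?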


Hello, thank you very much for the nice library!

I have implemented the ProdLDA model (link to the Pyro tutorial) with 31000 legal documents for a big Public Administration (PA).

The main reason is to sort a huge mass of PDF documents accumulated over the years. The goal is to provide probabilistic tags (i.e. “theta” in LDA terminology) for each document. This should help the PA search for similar documents and use them as templates for future documentation.

Since it was missing from the original tutorial, I added the following lines to the ProdLDA model in order to compute the “theta” (as percentages) for each document.

import torch
from pyro.infer import Predictive

# sub-sample prodLDA's results via the posterior predictive
predictive = Predictive(model=prodLDA.model,
                        guide=prodLDA.guide,
                        num_samples=2000,
                        return_sites=["logtheta"])
samples = predictive(docs)

# extract "theta" (percentages)
theta_percentages_numpy = torch.nn.functional.softmax(
    torch.mean(samples.get("logtheta").cpu(), dim=0), dim=-1).numpy()

I am able to run the whole pipeline successfully with up to 25000 documents on a single GPU (16 GB). As a reference, with 25000 documents “svi.step()” consumes 3-4 GB for 2-3 hours at 60-70% “Volatile GPU-Util” in “nvidia-smi”, while Predictive() consumes 12-13 GB at ~90% “Volatile GPU-Util” and takes 2-3 minutes. The results are great!

Above 25000 documents, the Predictive() step triggers an out-of-memory (OOM) error. I have noticed that with 31000 documents “svi.step()” still runs smoothly at around 4-5 GB for 3-4 hours. That’s why I can always get the “beta” (namely the word cloud), regardless of the size of the dataset. Nevertheless, I do need the “theta” as well. Unfortunately, with 31000 documents Predictive() breaks the whole pipeline: GPU memory consumption grows 3-4x within a few minutes.

That’s unfortunate, since it looks like a waste of computational power: the final three minutes of computation spoil the other 3-4 hours.

As I understand from the documentation, Predictive() is essentially a huge plate() on top of the actual model, which explains the OOM error: all num_samples draws are materialized at once.
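A rough back-of-envelope calculation (my numbers; assuming, for example, 20 topics as in the tutorial) illustrates the scale: the vectorized “logtheta” return site alone holds num_samples × num_docs × num_topics floats.

num_samples, num_docs, num_topics = 2000, 31000, 20
bytes_fp32 = 4  # float32
gib = num_samples * num_docs * num_topics * bytes_fp32 / 2**30
print(f"{gib:.1f} GiB")  # ~4.6 GiB for the return site alone, before any intermediates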

Below are the strategies I have tried, unsuccessfully, to avoid the OOM with 31000 documents:

  1. Decreasing the “batch_size” (e.g. from 512 to 16). It works, but then “svi.step()” becomes extremely slow (24 hours). I consider it a suboptimal strategy, because during “svi.step()” the GPU works at only 10-20% of its potential, leaving most of the “Volatile GPU-Util” idle for the whole 24 hours.

  2. I have tried moving Predictive() to the CPU. It works, but only single-threaded: it takes 8 hours to run Predictive(), whereas the GPU took 3 minutes! Setting “parallel=True” triggers an error, as mentioned in this post (link here). It looks like the problem is the “batch_size” inherited when I saved the trained “prodLDA” model.
    Ideal solution: it would have been wonderful to use “batch_size = 512” for “svi.step()” and “batch_size = 1” for Predictive(). This would have used the GPU to its fullest potential. Unfortunately, this strategy is not possible/easy.

  3. I suspected the problem was GPU memory accumulating between “svi.step()” and Predictive(). I saved the ProdLDA model (not an easy task in itself…), deleted it, and forced a hard GPU flush with “gc.collect()” and “torch.cuda.empty_cache()” before entering Predictive() (see the sketch after this list). The problem still persists, so I no longer think the issue is GPU memory carried over from “svi.step()” to Predictive().

  4. Using multiple GPUs via “horovod”. I have not tried this yet.
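For completeness, the flush in strategy 3 looked roughly like this (a sketch; the checkpoint path is illustrative):

import gc
import torch

torch.save(prodLDA.state_dict(), "prodlda_checkpoint.pt")  # save the trained model first
del prodLDA                # drop the Python reference to the model
gc.collect()               # collect anything now unreachable
torch.cuda.empty_cache()   # release cached CUDA blocks back to the driver
# ... then reload the model before calling Predictive()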

Honestly, I am running out of ideas for how to solve the OOM error with Predictive().

In conclusion, it looks like the culprit behind the GPU OOM error is Predictive().

I will reply to my own question. I hope it helps somebody else stuck in the same position as me.

After several attempts, I came up with a temporary solution for computing the “theta”/tags for a ProdLDA model without hitting the GPU OOM error.

The key idea was to replace Predictive() with AutoDelta(). Since AutoDelta fits a point (MAP) estimate of each latent variable, it only ever stores a single value per site instead of thousands of posterior samples. This way I can compute the “theta” on the GPU even for 31000 documents without errors.

# imports
from tqdm import trange

import pyro
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.infer.autoguide import AutoDelta

# initialize the autoguide
inferred_guide = AutoDelta(prodLDA.model)

# initialize SVI for the posterior fit
optimizer = pyro.optim.Adam({"lr": learning_rate})
svi_posterior = SVI(model=prodLDA.model,
                    guide=inferred_guide,
                    optim=optimizer,
                    loss=TraceMeanField_ELBO(num_particles=1))

# set the number of steps
num_samples_post_posterior = 2000

# the usual SVI training loop, but now with the previously trained ProdLDA model
bar = trange(num_samples_post_posterior)
for epoch in bar:
    # ELBO step
    loss = svi_posterior.step(docs)
    bar.set_postfix(epoch_loss="{:.2e}".format(loss))

# extract "theta"
theta_numpy = inferred_guide().get("logtheta").cpu().detach().numpy()

Using AutoDelta(), the computation is still about 10x slower (30 minutes) than Predictive() (3-4 minutes). Anyway, I am more than happy that I can now compute the “theta” for my 31000 documents within 30 minutes without running into a GPU OOM error.

For the record, I have also tried replacing AutoDelta() with either AutoNormal() or AutoLowRankMultivariateNormal(). With both of those options, however, the GPU memory blew up. The only option which worked for me was AutoDelta().

Lastly, AutoDelta() also worked without any error, even with num_particles=10. Very good!

I am using:
pyro-ppl==1.8.6
torch==1.13.1

I’m using an IBM Power8 machine with Ubuntu 20.04.2. The architecture is “ppc64le”, which makes it hard to install more recent PyTorch versions.