Why is Pyro consuming so much memory?

Hi there,
currently I am digging deeper into Pyro’s/PyTorch’s memory management, because I have to scale my existing Pyro implementation to very high-dimensional (genomics) data.
However, after running into CUDA memory issues, I wrote code to check how much GPU memory Pyro is allocating, and the amount seems surprisingly high to me.
I attached a simple example in which I model observed counts x as Poisson(softmax(A*z)), so I estimate both A and z as latent variables.

  • After loading all the data onto the GPU (i.e. before the first svi.step() call), 510 MB of GPU memory are occupied.
  • After one svi.step() call, 868 MB are in use, and this stays constant across all further iterations.

That means the latent variables A and z, as well as their product x_hat, must account for the remaining 358 MB.
The tensor sizes are:

  • (A) 2000 x 12 x 32 bit = 0.092 MB
  • (z) 12 x 10 x 32 bit = 0.47 KB
  • (x_hat) 2000 x 10 x 32 bit = 0.076 MB

Even if I double that because both location and scale parameters have to be estimated, and double the footprint again because gradients have to be stored, I end up at maybe 1 MB.
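
Spelled out, the back-of-the-envelope check I have in mind (float32 = 4 bytes per element):

bytes_A = 2000 * 12 * 4       # ~96 KB
bytes_z = 12 * 10 * 4         # ~480 B
bytes_x_hat = 2000 * 10 * 4   # ~80 KB
total = bytes_A + bytes_z + bytes_x_hat   # ~176 KB
print(total * 2 * 2 / 1024 ** 2, "MB")    # x2 for loc/scale, x2 for gradients -> ~0.7 MB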

Before scaling up the input data, I’d like to understand where this large overhead comes from.
I’d really appreciate any insights on that. It’s driving me nuts.

EDIT:

  • even wrapping the computation in torch.no_grad(), so that no gradients are computed, does not change the high memory footprint (rough sketch below)
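
Roughly what I mean by that (a sketch using the dataloader and mymodel from the script below: only the model forward pass, outside of SVI, so nothing should be retained for backward):

with torch.no_grad():
    for x in dataloader:
        x = x[0].cuda()
        mymodel.model(x=x)  # forward pass only, no gradient bookkeeping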

Thanks in advance!

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
import torch.nn as nn
from pyro.infer import SVI, TraceEnum_ELBO
from torch.nn.functional import softmax
from torch.utils.data import DataLoader, TensorDataset


class MyModel(nn.Module):
    def __init__(self, num_features=1000, latent_dim=12):
        super().__init__()
        self.num_features = num_features
        self.latent_dim = latent_dim
        self.guide = pyro.infer.autoguide.AutoDelta(self.model)

    def model(self, x):
        plates = self.get_plates(x)

        with plates["latents"], plates["features"]:
            z = pyro.sample("z", dist.Normal(torch.zeros(1), torch.ones(1)))

        # simplified example: a fixed ones matrix stands in for A here, so only z is estimated
        x_hat = torch.ones(size=(len(x), self.latent_dim)) @ z
        x_hat = softmax(x_hat, dim=-2)

        with plates["minibatch"], plates["features"]:
            pyro.sample(
                "obs_features",
                dist.Poisson(x_hat),
                obs=x,
            )

    def get_plates(self, x):
        return {
            "features": pyro.plate("features", self.num_features, dim=-1),
            "latents": pyro.plate("latents", self.latent_dim, dim=-2),
            "minibatch": pyro.plate("minibatch", len(x), dim=-2),
        }

# CONSTANTS
N_SAMPLES = 200
NUM_FEATURES = 2000
MINIBATCH_SIZE = 10
LATENT_DIM = 12

torch.cuda.set_device(0)
device = torch.device("cuda")
torch.cuda.empty_cache()
pyro.util.set_rng_seed(0)
pyro.clear_param_store()

# Create random poisson data
data = torch.randint(0, 20, size=(N_SAMPLES, NUM_FEATURES), dtype=torch.float32)
dataloader = DataLoader(
    TensorDataset(data), batch_size=MINIBATCH_SIZE, shuffle=False, pin_memory=True
)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Create Model
mymodel = MyModel(
    num_features=NUM_FEATURES,
    latent_dim=LATENT_DIM,
).cuda()

# Run SVI
optimizer = pyro.optim.Adam({"lr": 0.01})
elbo = TraceEnum_ELBO(strict_enumeration_warning=False)
svi = SVI(mymodel.model, mymodel.guide, optimizer, elbo)

for epoch in range(100):
    for x in dataloader:
        x = x[0].cuda()

        with poutine.scale(scale=1 / MINIBATCH_SIZE):
            loss = svi.step(x=x)

didn’t look at your model in detail, but it’s probably the einsum. The torch.einsum implementation isn’t always particularly efficient. You might try replacing it with e.g. torch.inner.


Thanks for the reply!

During debugging I already tried different approaches:

  • torch.einsum
  • torch.inner
  • A @ z

which all produce the same GPU memory consumption.
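
Concretely, the variants look roughly like this (a sketch, with A of shape (features, latent) and z of shape (latent, minibatch) as in the original model):

# three equivalent ways to form the (features x minibatch) product A @ z
x_hat = torch.einsum("fl,lm->fm", A, z)
x_hat = torch.inner(A, z.T)  # torch.inner contracts the last dimensions, hence the transpose
x_hat = A @ z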

EDIT: I simplified the example a bit more. Memory is still at about 880 MB, although I now only estimate a single matrix z.

AFAIK GPU memory allocation can be a bit non-intuitive, and you might just be seeing reserved-but-not-used memory; see e.g. this.
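
For example, comparing these two numbers (and the full torch.cuda.memory_summary() output) usually shows whether the footprint is memory actually backing live tensors or just memory cached by PyTorch's allocator:

print(torch.cuda.memory_allocated() / 1024**2, "MB backing live tensors")
print(torch.cuda.memory_reserved() / 1024**2, "MB held by the caching allocator (closer to what nvidia-smi reports)")
print(torch.cuda.memory_summary())  # detailed per-pool breakdown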