Hi there,

Currently, I am digging deeper into Pyro's/PyTorch's memory management, because I have to scale my existing Pyro implementation to really high-dimensional (genomics) data.

However, after running into CUDA memory issues, I wrote some code to check how much GPU memory Pyro allocates, and the numbers seem really high to me.

I attached a simple example where I model observed counts `x` as `Poisson(softmax(A*z))` and estimate both `A` and `z` as latent variables.

- When loading all the data onto the GPU (i.e. before the first `svi.step()` call), I occupy 510 MB.
- After one `svi.step()` call, I occupy 868 MB, and this stays the same for all following iterations.
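
For reference, I read these numbers from the CUDA caching allocator, roughly along these lines (a minimal sketch; `report_memory` is just a helper of mine, called around `svi.step()` in the script below):

```
import torch

def report_memory(tag):
    # tensors currently alive vs. memory reserved by the caching allocator
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"{tag}: allocated={allocated:.0f} MB, reserved={reserved:.0f} MB")
```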

That means the latent variables `A` and `z`, as well as their product `x_hat`, must use 358 MB.

The sizes of these tensors are:

- (A) 2000 x 12 x 32 bit = 0.092 MB
- (z) 12 x 10 x 32 bit = 0.47 KB
- (x_hat) 2000 x 10 x 32 bit = 0.076 MB

Even if I double that because both location and scale have to be estimated, and double the footprint again because gradients have to be stored, I reach maybe 1 MB.
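
Written out as a quick sanity check (float32 = 4 bytes per element):

```
bytes_per_float = 4  # float32

a_bytes = 2000 * 12 * bytes_per_float      # A
z_bytes = 12 * 10 * bytes_per_float        # z
x_hat_bytes = 2000 * 10 * bytes_per_float  # x_hat

total_mb = (a_bytes + z_bytes + x_hat_bytes) / 1024**2
print(total_mb)      # ~0.17 MB
print(total_mb * 4)  # ~0.7 MB, even with loc/scale and gradients
```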

To scale up the input data, I'd like to understand where this large overhead is coming from.

I’d really appreciate any insights on that. It’s driving me nuts.

EDIT:

- Even wrapping everything in `with torch.no_grad():` so that no gradients are computed does not change the high memory footprint.
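
What I tried there is roughly the following (a sketch, not the exact code; `elbo`, `mymodel`, and `x` are the objects from the script below):

```
# evaluate the objective without building a graph, then read the allocator
with torch.no_grad():
    loss = elbo.loss(mymodel.model, mymodel.guide, x=x)
print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
```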

Thanks in advance!

```
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
import torch.nn as nn
from pyro.infer import SVI, TraceEnum_ELBO
from torch.nn.functional import softmax
from torch.utils.data import DataLoader, TensorDataset


class MyModel(nn.Module):
    def __init__(self, num_features=1000, latent_dim=12):
        super().__init__()
        self.num_features = num_features
        self.latent_dim = latent_dim
        self.guide = pyro.infer.autoguide.AutoDelta(self.model)

    def model(self, x):
        plates = self.get_plates(x)
        with plates["latents"], plates["features"]:
            z = pyro.sample("z", dist.Normal(torch.zeros(1), torch.ones(1)))
        x_hat = torch.ones(size=(len(x), self.latent_dim)) @ z
        x_hat = softmax(x_hat, dim=-2)
        with plates["minibatch"], plates["features"]:
            pyro.sample(
                "obs_features",
                dist.Poisson(x_hat),
                obs=x,
            )

    def get_plates(self, x):
        return {
            "features": pyro.plate("features", self.num_features, dim=-1),
            "latents": pyro.plate("latents", self.latent_dim, dim=-2),
            "minibatch": pyro.plate("minibatch", len(x), dim=-2),
        }


# CONSTANTS
N_SAMPLES = 200
NUM_FEATURES = 2000
MINIBATCH_SIZE = 10
LATENT_DIM = 12

torch.cuda.set_device(0)
device = torch.device("cuda")
torch.cuda.empty_cache()
pyro.util.set_rng_seed(0)
pyro.clear_param_store()

# Create random poisson data
data = torch.randint(0, 20, size=(N_SAMPLES, NUM_FEATURES), dtype=torch.float32)
dataloader = DataLoader(
    TensorDataset(data), batch_size=MINIBATCH_SIZE, shuffle=False, pin_memory=True
)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Create Model
mymodel = MyModel(
    num_features=NUM_FEATURES,
    latent_dim=LATENT_DIM,
).cuda()

# Run SVI
optimizer = pyro.optim.Adam({"lr": 0.01})
elbo = TraceEnum_ELBO(strict_enumeration_warning=False)
svi = SVI(mymodel.model, mymodel.guide, optimizer, elbo)
for epoch in range(100):
    for x in dataloader:
        x = x[0].cuda()
        with poutine.scale(scale=1 / MINIBATCH_SIZE):
            loss = svi.step(x=x)
```
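
In case it is useful for debugging, I can also print PyTorch's own allocator summary after the training loop, e.g.:

```
# detailed breakdown of the caching allocator's state
print(torch.cuda.memory_summary(device=device))
print(f"max allocated: {torch.cuda.max_memory_allocated(device) / 1024**2:.0f} MB")
```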