Hi there,

Currently, I am digging deeper into Pyro's/PyTorch's memory management, because I have to scale my existing Pyro implementation to really high-dimensional (genomics) data.

However, after running into CUDA memory issues, I wrote some code to check how much GPU memory Pyro allocates, and the numbers seem really high to me.

I attached a simple example where I model observed counts `x` as `Poisson(softmax(A*z))` and estimate both `A` and `z` as latent variables.

- When loading all the data onto the GPU (i.e. before the first `svi.step()` call), I occupy 510 MB.
- After one `svi.step()` call, I occupy 868 MB, and this stays the same for all following iterations.
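
For reference, I read these numbers from the CUDA caching allocator, roughly along these lines (a minimal sketch; `report_memory` is just a helper of mine, called around `svi.step()` in the script below):

```
import torch

def report_memory(tag):
    # tensors currently alive vs. memory reserved by the caching allocator
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"{tag}: allocated={allocated:.0f} MB, reserved={reserved:.0f} MB")
```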

That means the latent variables `A` and `z`, as well as their product `x_hat`, must use 358 MB.

The sizes of these tensors are:

- (A) 2000 x 12 x 32 bit = 0.092 MB
- (z) 12 x 10 x 32 bit = 0.47 KB
- (x_hat) 2000 x 10 x 32 bit = 0.076 MB

Even if I double that because both location and scale have to be estimated, and double the footprint again because gradients have to be stored, I reach maybe 1 MB.
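
Written out as a quick sanity check (float32 = 4 bytes per element):

```
bytes_per_float = 4  # float32

a_bytes = 2000 * 12 * bytes_per_float      # A
z_bytes = 12 * 10 * bytes_per_float        # z
x_hat_bytes = 2000 * 10 * bytes_per_float  # x_hat

total_mb = (a_bytes + z_bytes + x_hat_bytes) / 1024**2
print(total_mb)      # ~0.17 MB
print(total_mb * 4)  # ~0.7 MB, even with loc/scale and gradients
```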

To scale up the input data, I'd like to understand where this large overhead is coming from.

I’d really appreciate any insights on that. It’s driving me nuts.

EDIT:

- Even wrapping everything in `with torch.no_grad():` so that no gradients are computed does not change the high memory footprint.
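
What I tried there is roughly the following (a sketch, not the exact code; `elbo`, `mymodel`, and `x` are the objects from the script below):

```
# evaluate the objective without building a graph, then read the allocator
with torch.no_grad():
    loss = elbo.loss(mymodel.model, mymodel.guide, x=x)
print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
```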

Thanks in advance!

```
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
import torch.nn as nn
from pyro.infer import SVI, TraceEnum_ELBO
from torch.nn.functional import softmax
from torch.utils.data import DataLoader, TensorDataset


class MyModel(nn.Module):
    def __init__(self, num_features=1000, latent_dim=12):
        super().__init__()
        self.num_features = num_features
        self.latent_dim = latent_dim
        self.guide = pyro.infer.autoguide.AutoDelta(self.model)

    def model(self, x):
        plates = self.get_plates(x)
        with plates["latents"], plates["features"]:
            z = pyro.sample("z", dist.Normal(torch.zeros(1), torch.ones(1)))
        x_hat = torch.ones(size=(len(x), self.latent_dim)) @ z
        x_hat = softmax(x_hat, dim=-2)
        with plates["minibatch"], plates["features"]:
            pyro.sample(
                "obs_features",
                dist.Poisson(x_hat),
                obs=x,
            )

    def get_plates(self, x):
        return {
            "features": pyro.plate("features", self.num_features, dim=-1),
            "latents": pyro.plate("latents", self.latent_dim, dim=-2),
            "minibatch": pyro.plate("minibatch", len(x), dim=-2),
        }


# CONSTANTS
N_SAMPLES = 200
NUM_FEATURES = 2000
MINIBATCH_SIZE = 10
LATENT_DIM = 12

torch.cuda.set_device(0)
device = torch.device("cuda")
torch.cuda.empty_cache()
pyro.util.set_rng_seed(0)
pyro.clear_param_store()

# Create random poisson data
data = torch.randint(0, 20, size=(N_SAMPLES, NUM_FEATURES), dtype=torch.float32)
dataloader = DataLoader(
    TensorDataset(data), batch_size=MINIBATCH_SIZE, shuffle=False, pin_memory=True
)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Create Model
mymodel = MyModel(
    num_features=NUM_FEATURES,
    latent_dim=LATENT_DIM,
).cuda()

# Run SVI
optimizer = pyro.optim.Adam({"lr": 0.01})
elbo = TraceEnum_ELBO(strict_enumeration_warning=False)
svi = SVI(mymodel.model, mymodel.guide, optimizer, elbo)
for epoch in range(100):
    for x in dataloader:
        x = x[0].cuda()
        with poutine.scale(scale=1 / MINIBATCH_SIZE):
            loss = svi.step(x=x)
```
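
In case it is useful for debugging, I can also print PyTorch's own allocator summary after the training loop, e.g.:

```
# detailed breakdown of the caching allocator's state
print(torch.cuda.memory_summary(device=device))
print(f"max allocated: {torch.cuda.max_memory_allocated(device) / 1024**2:.0f} MB")
```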