A trick to debug tensor memory

I’d like to share a trick we use on the Pyro team to debug memory leaks, since this is such a common type of bug. We often use the following memory footprint analyzer:

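def debug_memory():
    import collections, gc, resource, torch
    # Peak resident set size of the process so far (units are OS-dependent).
    print('maxrss = {}'.format(
        resource.getrusage(resource.RUSAGE_SELF).ru_maxrss))
    # Count the live tensors tracked by the garbage collector,
    # grouped by (device, dtype, shape).
    tensors = collections.Counter(
        (str(o.device), str(o.dtype), tuple(o.shape))
        for o in gc.get_objects()
        if torch.is_tensor(o)
    )
    # One line per group: the key, a tab, then the number of tensors.
    for line in sorted(tensors.items()):
        print('{}\t{}'.format(*line))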
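The raw ru_maxrss value isn't directly comparable between machines: getrusage reports it in kilobytes on Linux but in bytes on macOS. If you'd rather print a fixed unit, here's a minimal sketch (`maxrss_mib` is a made-up helper name, not something Pyro or PyTorch provides):

```python
import resource
import sys


def maxrss_mib():
    """Hypothetical helper: peak resident set size of this process, in MiB.

    Assumes ru_maxrss is in bytes on macOS and in kilobytes elsewhere
    (the Linux convention). Unix-only, since `resource` is unavailable on Windows.
    """
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return rss / (1024 * 1024)  # bytes -> MiB
    return rss / 1024  # kilobytes -> MiB


print("maxrss = {:.1f} MiB".format(maxrss_mib()))
```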

Sometimes I’ll put this in a training loop and run the program in a terminal. Let’s see how it works:

>>> debug_memory()

>>> x = torch.randn(3, 3)
>>> debug_memory()
('cpu', 'torch.float32', (3, 3))	1

>>> y = torch.randn(3, 3)
>>> debug_memory()
('cpu', 'torch.float32', (3, 3))	2

>>> z = [torch.randn(i).long() for i in range(10)]
>>> debug_memory()
('cpu', 'torch.float32', (3, 3))	2
('cpu', 'torch.int64', (0,))	1
('cpu', 'torch.int64', (1,))	1
('cpu', 'torch.int64', (2,))	1
('cpu', 'torch.int64', (3,))	1
('cpu', 'torch.int64', (4,))	1
('cpu', 'torch.int64', (5,))	1
('cpu', 'torch.int64', (6,))	1
('cpu', 'torch.int64', (7,))	1
('cpu', 'torch.int64', (8,))	1
('cpu', 'torch.int64', (9,))	1

>>> del x, z
>>> debug_memory()
('cpu', 'torch.float32', (3, 3))	1
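In a training loop the useful signal is usually which key keeps growing between checkpoints. Here's a sketch that diffs two censuses instead of printing the whole table. The helpers `tensor_census` and `census_diff`, the toy model, and the leak itself (keeping `loss` instead of `loss.item()`) are all made up for illustration, and it assumes PyTorch's default float32 dtype:

```python
import collections
import gc

import torch


def tensor_census():
    """Hypothetical helper: count live tensors by (device, dtype, shape), like debug_memory()."""
    return collections.Counter(
        (str(o.device), str(o.dtype), tuple(o.shape))
        for o in gc.get_objects()
        if torch.is_tensor(o)
    )


def census_diff(before, after):
    """Hypothetical helper: keys whose count changed, as {key: after - before}."""
    keys = set(before) | set(after)
    return {k: after[k] - before[k] for k in sorted(keys) if after[k] != before[k]}


# Toy model, optimizer and leak, all made up for this illustration.
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
history = []  # meant to record the loss curve
diffs = []    # census_diff() between consecutive checkpoints

previous = tensor_census()
for step in range(1, 31):
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean()  # a scalar float32 tensor
    opt.zero_grad()
    loss.backward()
    opt.step()
    history.append(loss)  # the leak: history.append(loss.item()) would store a float
    if step % 10 == 0:
        current = tensor_census()
        diffs.append(census_diff(previous, current))
        print("step {}: {}".format(step, diffs[-1]))
        previous = current
```

Between later checkpoints the scalar float32 count should go up by 10 each time, which points at the one scalar kept per step; the first diff also picks up one-time allocations such as the gradients and the input batch.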

Let me know if you have any enhancements!

EDIT 2019-05-07: added maxrss printing

EDIT 2022-04-13: added str() around o.dtype


For me this yields the error:

Traceback (most recent call last):
  File "pyro_vae_cnn.py", line 285, in <module>
    debug_memory()
  File "pyro_vae_cnn.py", line 49, in debug_memory
    for line in sorted(tensors.items()):
TypeError: '<' not supported between instances of 'torch.dtype' and 'torch.dtype'

I have Pyro 0.3.1 and PyTorch 1.0.1.

The snippet works for me with the same versions. Try removing the sorted() call. What Python version are you using?
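For anyone hitting the same TypeError: sorted() compares the (device, dtype, shape) keys as tuples, element by element, so whenever two keys share a device it ends up comparing the dtypes, and on the version in the traceback torch.dtype objects can't be ordered. Wrapping o.dtype in str() (the 2022-04-13 edit) avoids that. A small reproduction with hand-built keys (the key values are made up for illustration):

```python
import collections

import torch

# Hand-built keys in the pre-2022 format, holding raw torch.dtype objects.
raw_keys = collections.Counter({
    ("cpu", torch.float32, (3, 3)): 2,
    ("cpu", torch.int64, (3,)): 1,
})

# Both keys start with "cpu", so tuple comparison moves on to the dtypes,
# and torch.dtype supports == but not < (as the traceback shows).
raw_sort_failed = False
try:
    sorted(raw_keys.items())
except TypeError as err:
    raw_sort_failed = True
    print("sorting raw keys failed:", err)

# With str(dtype), as in the 2022-04-13 edit, every key element is orderable.
str_keys = collections.Counter({
    (device, str(dtype), shape): n
    for (device, dtype, shape), n in raw_keys.items()
})
for line in sorted(str_keys.items()):
    print("{}\t{}".format(*line))
```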