I’d like to share a trick we use on the Pyro team to debug memory leaks, since this is such a common type of bug. We often use the following memory footprint analyzer:
def debug_memory():
    import collections, gc, resource, torch
    # Peak resident set size of the process so far
    # (kilobytes on Linux, bytes on macOS).
    print('maxrss = {}'.format(
        resource.getrusage(resource.RUSAGE_SELF).ru_maxrss))
    # Count every live tensor the garbage collector can see,
    # keyed by (device, dtype, shape).
    tensors = collections.Counter(
        (str(o.device), str(o.dtype), tuple(o.shape))
        for o in gc.get_objects()
        if torch.is_tensor(o)
    )
    for line in sorted(tensors.items()):
        print('{}\t{}'.format(*line))
Sometimes I’ll put this in a training loop and watch the output in a terminal, as in the sketch below.
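Here is a minimal sketch of that pattern, assuming you already have model, optimizer, and data_loader objects (those names are placeholders, not part of the analyzer). A leak shows up as a tensor count that keeps growing from one snapshot to the next:

# Hypothetical training loop; only the debug_memory() calls are the point here.
for step, batch in enumerate(data_loader):
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:    # snapshot every 100 steps to keep output readable
        print('step {}'.format(step))
        debug_memory()     # leaked tensors show up as ever-growing counts

Let’s see how the analyzer itself works: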
>>> debug_memory()
>>> x = torch.zeros(3, 3)
>>> debug_memory()
('cpu', torch.float32, (3, 3)) 1
>>> y = torch.zeros(3, 3)
>>> debug_memory()
('cpu', torch.float32, (3, 3)) 2
>>> z = [torch.randn(i).long() for i in range(10)]
>>> debug_memory()
('cpu', torch.float32, (3, 3)) 2
('cpu', torch.int64, (0,)) 1
('cpu', torch.int64, (1,)) 1
('cpu', torch.int64, (2,)) 1
('cpu', torch.int64, (3,)) 1
('cpu', torch.int64, (4,)) 1
('cpu', torch.int64, (5,)) 1
('cpu', torch.int64, (6,)) 1
('cpu', torch.int64, (7,)) 1
('cpu', torch.int64, (8,)) 1
('cpu', torch.int64, (9,)) 1
>>> del x, z
>>> debug_memory()
('cpu', torch.float32, (3, 3)) 1
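One caveat: del works immediately here because CPython frees an object as soon as its reference count hits zero, but a tensor caught in a reference cycle lingers until the cyclic garbage collector runs. A contrived sketch (the Node class is made up for illustration):

import gc, torch

class Node:
    def __init__(self):
        self.tensor = torch.zeros(3, 3)
        self.me = self    # reference cycle keeps this object alive

n = Node()
del n
# debug_memory() would still count the (3, 3) tensor at this point ...
gc.collect()
# ... but after a collection the cycle is broken and the tensor is freed.

So if a count refuses to drop, try calling gc.collect() before concluding that something still holds a real reference.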
Let me know if you have any enhancements!
EDIT 2019-05-07: added maxrss printing
EDIT 2022-04-13: added str() around o.dtype