Reducing memory requirement for NUTS

When performing parameter inference with NUTS on CUDA, I observed a linear increase in allocated memory that terminates the process within a few iterations. This only happens after warmup has finished, and it apparently also happens on CPU. A minimal program that reproduces the behavior is:

import torch
import pyro
from pyro.infer.mcmc import MCMC, NUTS
import pyro.distributions as dist

# large observed tensor (~4 MB as float32)
data = torch.ones((1000, 1000))

def model():
    x = pyro.sample('x', dist.Normal(0., 1.))
    pyro.sample('data', dist.Normal(x * data, 0.01), obs=data)

nuts_kernel = NUTS(model)
posterior = MCMC(nuts_kernel, num_samples=1000, warmup_steps=0).run()

I suspect that the data is somehow copied and stored at each iteration. However, del posterior does not free any significant amount of memory (and it appears that only references to the original data object are stored, not copies). How can I avoid this excessive memory consumption?
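For what it's worth, this is roughly how the growth can be observed in the CUDA case (illustrative sketch, assuming data and the model have been moved to the GPU):

print(torch.cuda.memory_allocated())  # baseline before sampling

posterior = MCMC(nuts_kernel, num_samples=1000, warmup_steps=0).run()
print(torch.cuda.memory_allocated())  # grows roughly linearly with num_samples

del posterior
print(torch.cuda.memory_allocated())  # barely drops in my case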

(torch 1.0.1, pyro 0.3.1, python 3.6.7)

Thanks for the clear report! I have created an issue on GitHub to follow up. The memory increases per iteration because we store num_samples execution traces, which is more problematic on the GPU given its tighter memory limits. I think the only way around this is to store only the bare minimum of information from each trace, as mentioned in this issue (yet again!), so I'll definitely make sure that we address this as part of the refactoring.
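In the meantime, one thing you could try on the user side is to keep only the latent values you care about and drop the full traces as soon as the run finishes. This is just a sketch against the 0.3.1 internals: it assumes the stored traces are reachable as posterior.exec_traces, and it only releases memory after sampling, not the peak usage during the run (fixing that needs the refactoring above):

posterior = MCMC(nuts_kernel, num_samples=1000, warmup_steps=0).run()

# keep only the value of the latent site 'x' from each stored trace, moved off the GPU
x_samples = torch.stack([tr.nodes['x']['value'].detach().cpu()
                         for tr in posterior.exec_traces])

# drop the traces so their cached tensors (including log-probs at the large observed site) can be freed
posterior.exec_traces.clear()
del posterior
torch.cuda.empty_cache()  # only relevant when sampling on CUDA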