When performing parameter inference with NUTS on CUDA, I observed a linear increase in allocated memory that kills the process within a few iterations. This only happens after warmup has finished, and it apparently also happens on CPU. A minimal program that reproduces the behavior:
```python
import torch
import pyro
from pyro.infer.mcmc import MCMC, NUTS
import pyro.distributions as dist

data = torch.ones((1000, 1000))

def model():
    x = pyro.sample('x', dist.Normal(0., 1.))
    pyro.sample('data', dist.Normal(x * data, 0.01), obs=data)

nuts_kernel = NUTS(model)
posterior = MCMC(nuts_kernel, 1000, warmup_steps=0).run()
```
I suspect that the data is somehow copied and stored at each iteration. However, `del posterior` does not free any significant amount of memory (and it appears that only references to, not copies of, the original data object are stored). How can I avoid this excessive memory consumption?
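To sanity-check the reference-vs-copy distinction, here is a stdlib-only sketch (plain Python lists stand in for tensors, which is my assumption, not the actual Pyro internals): storing a reference per iteration barely grows traced memory, while storing a copy per iteration grows it linearly, which matches the behavior I see.

```python
import tracemalloc

# A large object standing in for the (1000, 1000) tensor.
data = [0.0] * 100_000

tracemalloc.start()
base, _ = tracemalloc.get_traced_memory()

# Keeping only references across "iterations": memory stays roughly flat.
refs = [data for _ in range(100)]
ref_cur, _ = tracemalloc.get_traced_memory()

# Keeping a copy per "iteration": memory grows linearly with iterations.
copies = [list(data) for _ in range(10)]
copy_cur, _ = tracemalloc.get_traced_memory()

ref_growth = ref_cur - base        # tiny: just list-of-pointers overhead
copy_growth = copy_cur - ref_cur   # large: one full copy per iteration
print(ref_growth < copy_growth)    # prints True
```

If something inside the sampler were holding per-iteration copies of the trace (including `data`), I would expect exactly this linear growth pattern.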
(torch 1.0.1, pyro 0.3.1, python 3.6.7)