When performing parameter inference with NUTS on CUDA, I observe a linear increase in allocated memory that terminates the process within a few iterations. This only happens after warmup is finished, and it apparently also happens on CPU. A minimal program to reproduce the behavior is:
import torch
import pyro
from pyro.infer.mcmc import MCMC, NUTS
import pyro.distributions as dist
data = torch.ones((1000, 1000))
def model():
    x = pyro.sample('x', dist.Normal(0., 1.))
    pyro.sample('data', dist.Normal(x * data, 0.01), obs=data)
nuts_kernel = NUTS(model)
posterior = MCMC(nuts_kernel, 1000, warmup_steps=0).run()
I suspect that the data is somehow copied and stored at each iteration. However, del posterior does not free any significant amount of memory, and it appears that only references to the original data object, not copies, are stored. How can I avoid this excessive memory consumption?
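One way to confirm whether allocations really accumulate per iteration is to snapshot memory between loop iterations. Below is a minimal, torch-independent sketch of that pattern using the stdlib tracemalloc module; the loop body here is a stand-in for an MCMC sampling step, and the list `store` simulates a container that retains a reference each iteration (all names are illustrative, not part of Pyro's API):

```python
import tracemalloc

tracemalloc.start()
history = []
store = []  # stand-in for a container that retains data each iteration
for i in range(5):
    store.append([0] * 10000)  # stand-in for an accumulated copy
    current, peak = tracemalloc.get_traced_memory()
    history.append(current)
tracemalloc.stop()

# If memory is retained each iteration, successive readings increase.
growing = all(b > a for a, b in zip(history, history[1:]))
print(growing)  # → True
```

Applying the same snapshotting around the actual sampler loop (and, on GPU, reading torch.cuda.memory_allocated() instead) would distinguish linear growth from a one-time allocation.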
(torch 1.0.1, pyro 0.3.1, python 3.6.7)