Hi,
I’m running into an out-of-memory (OOM) problem. A toy version is below (using the latest versions of NumPyro and JAX, profiling with memory_profiler):
```python
import time

import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from memory_profiler import profile  # decorator that reports line-by-line memory usage
from numpyro.infer import MCMC, DiscreteHMCGibbs, HMC


def model():
    index = numpyro.sample('index', dist.Categorical(probs=np.array([[0.5, 0.5]])))
    mu = numpyro.sample('mu', dist.Normal(0, 1), sample_shape=(2,))
    return


@profile
def test_fun():
    for i in range(10):
        # A new kernel and MCMC object are built on every iteration
        kernel = DiscreteHMCGibbs(HMC(model), random_walk=True)
        mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1, progress_bar=True)
        mcmc.run(random.PRNGKey(0))
        time.sleep(1)


test_fun()
```
Essentially, memory usage grows with each loop iteration. I had a look at the open issues on GitHub, and this one seems related to mine. In the initial post, `kernel` and `mcmc` were instantiated inside the loop, and the solution was shown to work once they were instantiated outside the loop instead. My problem involves the former case, because I change my model on every iteration.
Is there a way to essentially clean up the `kernel` and `mcmc` objects?
Thank you!