Memory Issue

Hi,

I’m running into an out-of-memory (OOM) problem; a toy version is below (run with the latest versions of NumPyro and JAX, and profiled with memory_profiler):

import time

import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from memory_profiler import profile  # lets the script run without `python -m memory_profiler`
from numpyro.infer import MCMC, DiscreteHMCGibbs, HMC

def model():
    # One discrete site (updated by the Gibbs step) and one continuous site (updated by HMC).
    index = numpyro.sample('index', dist.Categorical(probs=np.array([[0.5, 0.5]])))
    mu = numpyro.sample('mu', dist.Normal(0, 1), sample_shape=(2,))

@profile
def test_fun():
    for i in range(10):
        # A new kernel and MCMC object are created on every iteration.
        kernel = DiscreteHMCGibbs(HMC(model), random_walk=True)
        mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1, progress_bar=True)
        mcmc.run(random.PRNGKey(0))
        time.sleep(1)

test_fun()

Essentially, memory usage grows with each loop iteration. I had a look at the open issues on GitHub, and this seems to relate to my issue. In the initial post there, kernel and mcmc were instantiated inside the loop, but the solution was then shown to work when they were instantiated outside the loop. My problem involves the former case, since I make changes to my model within each iteration.

Is there a way to essentially clean up the kernel and mcmc objects?

Thank you!

Hi @emyjo, could you try those clear_cache methods? I’m not sure how to properly address this issue. :frowning: IIRC the only place where we do caching is this line, so you might want to reduce the cache size from 8 to 1. If that resolves the memory issue, then please make a feature request to expose a global parameter in numpyro.util for modifying the size. :slight_smile:
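To illustrate why the cache size could matter here, below is a minimal standalone sketch of a size-bounded cache. It is not NumPyro's code: `bounded_cache` and `compile_for` are hypothetical names, and the `bytearray` is only a stand-in for whatever a cache entry keeps alive. The assumption is that each new kernel in the loop produces a new cache key, so up to `max_size` old entries stay referenced.

```python
from collections import OrderedDict


def bounded_cache(max_size):
    """Hypothetical decorator: cache results by key, keeping at most max_size entries."""
    def decorator(fn):
        cache = OrderedDict()

        def wrapper(key):
            if key in cache:
                cache.move_to_end(key)  # mark as most recently used
                return cache[key]
            value = fn(key)
            cache[key] = value
            if len(cache) > max_size:
                cache.popitem(last=False)  # evict the least recently used entry
            return value

        wrapper.cache = cache  # exposed so the cache can be inspected or cleared
        return wrapper
    return decorator


@bounded_cache(max_size=8)
def compile_for(key):
    # Stand-in for an expensive artifact tied to one kernel instance.
    return bytearray(1024)


for i in range(10):
    compile_for(("kernel", i))  # a new key per iteration, like a new kernel each time

print(len(compile_for.cache))  # 8: the two oldest entries were evicted
```

With `max_size=1` only the most recent entry would remain referenced, and `compile_for.cache.clear()` would drop everything. Whether this accounts for the growth seen in the thread is not established; a bounded cache of this kind should plateau rather than grow without limit.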

Hi @fehiepsi, thanks for the suggestions, unfortunately neither seemed to fix the issue.

I ended up adding code to clear the cache of this line. Is this the only place that I should be concerned about?

Also, I did notice this issue and was wondering if it might be connected?

How can you clear the cache there? I think you can remove jit there and move it to this line. Then the jitted function will be cached, and we can think about how to deal with that cache (I don’t know how). IIRC it is the only place where we do jitting (not completely sure).

Hi @fehiepsi, thank you for your response. I moved it to the suggested line by using a decorator, but it hasn’t fixed the issue. Is using a decorator the correct way?

I’m not sure how to solve the issue, actually. I thought you already had a solution; that’s why I asked.
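Since the thread ends without a fix, here is a hedged sketch of a general workaround that was not proposed by anyone above: run each iteration in a fresh Python process, so that whatever the iteration caches is released when the process exits. `CHILD_CODE` and `run_iteration` are hypothetical names, and the child body is a placeholder for building the model, kernel, and MCMC object and reporting results.

```python
import subprocess
import sys

# Hypothetical stand-in for one iteration's work. A real child script would
# build the model, run MCMC, and write its results to stdout or to disk.
CHILD_CODE = """
import sys
i = int(sys.argv[1])
scratch = [bytearray(1024) for _ in range(1000)]  # memory that dies with the process
print(i * i)
"""


def run_iteration(i):
    # Each call starts a fresh interpreter, so nothing it caches survives the call.
    result = subprocess.run(
        [sys.executable, "-c", CHILD_CODE, str(i)],
        capture_output=True,
        text=True,
        check=True,
    )
    return int(result.stdout.strip())


results = [run_iteration(i) for i in range(3)]
print(results)
```

The trade-off is that every iteration pays interpreter start-up and recompilation costs, so this only makes sense when the model really does change between iterations, as in the original question.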