GPU memory preallocated and not released between batches

Hi. I think my MCMC program runs out of GPU memory because memory is not being released between batches. I know this issue has been raised before on this forum and on GitHub (e.g., #539), but I believe I have tried all of the proposed fixes without much luck. Here’s the kind of error I’m getting:

ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1529020400 bytes.

The above message does not really make sense, as this machine has 16 GB of GPU memory. However, it appears that most of it is being preallocated and never released.

Initially, I tried:

batches = []
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=100)
for i in range(10):
    mcmc.run(random.PRNGKey(i), *model_args)
    batches.append(mcmc.get_samples())
    mcmc._warmup_state = mcmc._last_state

then:

batches = []
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=100)

for i in range(10):

    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

    mcmc.run(random.PRNGKey(i), *model_args)
    batches.append(mcmc.get_samples())
    mcmc._warmup_state = mcmc._last_state

    del os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]
    gc.collect()

Both versions hit the above error after 1 or 2 batches (when attempting to run the next batch). As always, any guidance or advice would be greatly appreciated. Thank you!

I’ve narrowed the issue down a little. JAX was still allocating 90% of the GPU memory because os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' has to be set ahead of my imports, so that part is now fixed. However, memory is still not being deallocated between batches, even with os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform', so the run still errors out after a few batches.
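In case it helps anyone who lands here later, this is roughly the ordering I mean. It is only a sketch: configure_xla_memory is a hypothetical helper name of my own, not part of JAX or NumPyro, and the warning relies on my assumption that the flags are only honored if they are set before JAX starts up its GPU backend (setting them before the import is the safe option).

```python
import os
import sys
import warnings


def configure_xla_memory(preallocate=False, allocator="platform"):
    """Hypothetical helper: set the XLA client memory flags via os.environ.

    Meant to be called at the very top of the script, before jax or
    numpyro are imported.
    """
    if "jax" in sys.modules:
        # Assumption: once jax is already loaded, the flags may be ignored.
        warnings.warn("jax is already imported; these settings may have no effect")
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if allocator is not None:
        os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = allocator
    keys = ("XLA_PYTHON_CLIENT_PREALLOCATE", "XLA_PYTHON_CLIENT_ALLOCATOR")
    # Report what is now set so it can be logged or checked.
    return {k: os.environ[k] for k in keys if k in os.environ}


settings = configure_xla_memory()
# Only now import jax / numpyro, e.g.:
# import jax
# from numpyro.infer import MCMC, NUTS
```

Setting the variables inside the batch loop, as in my second attempt above, comes too late, because JAX has already been imported by then.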

I’m not sure if it helps, but you might try adding

list(mcmc.get_samples().values())[0].block_until_ready()

between iterations (you might also try MCMC(..., jit_model_args=True)). What happens with your gc.collect()? Do you see a memory leak between iterations?