Batching MCMC OOM issue

Hi all,
I'm working on a model on a single GPU and am running into an out-of-memory (OOM) error.
All the chains are held in GPU memory at once, which exceeds the GPU's 16 GB.

To work around this, I implemented batching, i.e. running the chains one at a time (Forum 1, Forum 2).
This runs with 2000 samples. However, BFMI is low, so I need to increase the number of samples.
With 4000 samples the OOM error comes back. It seems that two chains are always kept in memory (one chain takes 6 GB, so two take 12 GB, which triggers the OOM on the 16 GB GPU).
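As a rough illustration of why the per-batch footprint grows with num_samples, the sketch below counts the values one chain stores for the model in this post. It is a back-of-the-envelope estimate under stated assumptions: every sample and deterministic site is kept for every draw as float32, and extra fields, the unconstrained copy of the state and XLA's working buffers are ignored, so the real footprint will be larger. stored_draws_bytes is a hypothetical helper, not part of NumPyro.

```python
def stored_draws_bytes(num_samples, n_items, n_factors, n_persons, itemsize=4):
    """Rough size of the posterior draws one chain keeps (hypothetical helper).

    Assumes every sample/deterministic site of the model is stored for each
    draw with `itemsize` bytes per value (4 = float32). Ignores extra fields,
    the unconstrained state copy and XLA scratch memory.
    """
    per_draw = (
        n_items                        # diff
        + 2 * n_items * n_factors      # discrim_offset + deterministic discrim
        + n_persons * n_factors ** 2   # ability_corr: one matrix per person
        + n_persons * n_factors        # ability
    )
    return num_samples * per_draw * itemsize


# Example with small, made-up dimensions
print(stored_draws_bytes(1000, n_items=10, n_factors=2, n_persons=100))
```

Because the count is linear in num_samples, doubling the draws roughly doubles the stored samples, and the n_persons × n_factors² term for the per-person ability_corr matrices will usually dominate when n_persons is large.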

Below are the experiments I have run so far:

| XLA_PYTHON_CLIENT_PREALLOCATE | XLA_PYTHON_CLIENT_ALLOCATOR | num_samples | num_warmup | Chain method | VRAM at start of each batch | Ran successfully? |
|---|---|---|---|---|---|---|
| False | Platform | 4000 | 10000 | sequential | 1: 744 MB, 2: 6,566 MB | No |
| False | Platform | 2000 | 10000 | sequential | 1: 744 MB, 2: 3,695 MB, 3: 3,695 MB, 4: 3,695 MB | Yes |
| Default | Default | 2000 | 10000 | sequential | 1: 13,385 MB, 2: 13,385 MB | No |

As others have reported in the forums, the second batch usually triggers the OOM error. I mainly tried setting two environment variables, XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_ALLOCATOR, as suggested in the threads mentioned above.
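For reference, JAX reads these variables when it initialises its GPU backend, so the safe option is to set them before JAX is imported (or export them in the shell). A minimal sketch, assuming a plain Python entry point; configure_xla_memory is a hypothetical helper name. Per the JAX GPU memory documentation, PREALLOCATE=false makes JAX allocate on demand instead of reserving a large fraction of GPU memory up front, and ALLOCATOR=platform frees memory as soon as it is no longer needed, at the cost of slower allocation.

```python
import os
import sys
import warnings


def configure_xla_memory(preallocate: bool = False, allocator: str = "platform") -> None:
    """Set JAX/XLA GPU memory environment variables (hypothetical helper).

    Call this before `import jax`; values set later may be ignored.
    """
    if "jax" in sys.modules:
        # JAX may already have read the variables, so the new values may not apply
        warnings.warn("jax is already imported; XLA memory settings may be ignored")
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = allocator


configure_xla_memory()
# import jax, numpyro, ... only after this point
```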

When sampling the second batch, not all of the GPU memory from the first batch has been released. I assume this is the previous sampling state; it is not released until the whole run has finished.
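One way to check whether the memory is really still held is to ask JAX's allocator rather than nvidia-smi: with the default allocator, freed buffers go back into JAX's own pool, so nvidia-smi keeps reporting them as used. The helper below is a hypothetical sketch built on Device.memory_stats(); the keys it contains depend on the backend, and the helper returns None where statistics are not available.

```python
import jax


def device_memory_summary(device=None):
    """Return live/peak bytes reported by JAX's allocator, or None (hypothetical helper)."""
    device = device if device is not None else jax.local_devices()[0]
    try:
        stats = device.memory_stats()
    except Exception:  # some backends do not implement allocator statistics
        return None
    if not stats:
        return None
    # .get() because the available keys differ between backends/versions
    return {
        "bytes_in_use": stats.get("bytes_in_use"),
        "peak_bytes_in_use": stats.get("peak_bytes_in_use"),
    }


print(device_memory_summary())
```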

Current implementation of the model and the batching:



Model definition and NUTS+MCMC configuration

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import plate


def model(n_items, n_factors, n_persons, responses_mask, responses=None):
    with plate("diff_dim1", n_items, dim=-1):
        diff = numpyro.sample("diff", dist.Normal(loc=0.0, scale=1.0))

    with plate("discrim_dim1", n_factors, dim=-1):
        with plate("discrim_dim2", n_items, dim=-2):
            discrim_offset = numpyro.sample(
                "discrim_offset", dist.LogNormal(loc=0.0, scale=1.0)
            )
            # Need to use deterministic layer to apply q_matrix
            # (item_factors is the Q-matrix, defined outside this snippet)
            discrim = numpyro.deterministic(
                "discrim", discrim_offset * item_factors
            )

    with plate("ability_dim1", n_persons, dim=-1):
        corr = numpyro.sample("ability_corr", dist.LKJ(n_factors, jnp.ones([1])))

        # Mean would always be 0
        mu = jnp.zeros(n_factors)
        ability = numpyro.sample(
            "ability", dist.MultivariateNormal(loc=mu, covariance_matrix=corr)
        )

    # Logistic regression: (n_persons, n_factors) @ (n_factors, n_items) + (n_items,)
    logits = jnp.dot(ability, discrim.T) + diff

    # Likelihood (assumed: not shown in the original post)
    numpyro.sample(
        "obs", dist.Bernoulli(logits=logits).mask(responses_mask), obs=responses
    )

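from numpyro.infer import MCMC, NUTS

# Defining sampler
kernel = NUTS(
    model, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False
)

# One chain per batch; the remaining arguments were cut off in the original post,
# these values are taken from the experiments table
mcmc = MCMC(
    kernel, num_warmup=10000, num_samples=2000, num_chains=1, chain_method="sequential"
)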

Batching function

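import arviz as az
import jax
import numpy as np

cpu = jax.devices("cpu")[0]
_traces = []
_extra_fields = []

for i in range(chains):
    # Run one batch; only the first run performs warmup
    # (run arguments assumed: the call is not shown in the original post)
    mcmc.run(
        jax.random.PRNGKey(i), n_items, n_factors, n_persons, responses_mask,
        responses=responses, extra_fields=("diverging", "energy"),
    )

    # Transfer to CPU and keep the host copies
    samples = jax.device_put(mcmc.get_samples(group_by_chain=True), cpu)
    extra_fields = jax.device_put(mcmc.get_extra_fields(group_by_chain=True), cpu)
    _traces.append(samples)
    _extra_fields.append(extra_fields)
    del samples, extra_fields

    # Set warmup state to the next run (later runs skip warmup)
    mcmc.post_warmup_state = mcmc.last_state

# Prepare the traces for arviz: stack the batches along the chain axis
trace = {k: np.concatenate([t[k] for t in _traces]) for k in _traces[0]}
extras = {k: np.concatenate([e[k] for e in _extra_fields]) for k in _extra_fields[0]}

idata = az.convert_to_inference_data(trace)
iextra = az.convert_to_inference_data(extras, group="sample_stats")
az.concat(idata, iextra, inplace=True)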
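Because the warmup state is set to the previous batch's last state, each batch continues the same Markov chain instead of starting an independent one, while stacking the batches along the chain axis (axis 0, as the concatenation above does) makes ArviZ treat them as separate chains in diagnostics such as R-hat. A small NumPy-only sketch, assuming each batch is a dict of arrays shaped (1, draws, ...) as returned by get_samples(group_by_chain=True); merge_batches is a hypothetical helper whose axis argument chooses between the two layouts.

```python
import numpy as np


def merge_batches(batches, axis=0):
    """Merge a list of {name: array(1, draws, ...)} dicts (hypothetical helper).

    axis=0 stacks the batches as separate chains: (n_batches, draws, ...).
    axis=1 appends them as consecutive draws of one chain: (1, n_batches * draws, ...).
    """
    return {k: np.concatenate([b[k] for b in batches], axis=axis) for k in batches[0]}


batches = [{"x": np.zeros((1, 3, 2))}, {"x": np.ones((1, 3, 2))}]
as_chains = merge_batches(batches)              # shape (2, 3, 2)
as_one_chain = merge_batches(batches, axis=1)   # shape (1, 6, 2)
```

The merged dicts can then be passed to ArviZ in one call, e.g. az.from_dict(posterior=trace, sample_stats=extras).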


  1. Is there any way to reduce memory consumption and avoid the OOM error?
  2. Are there any suggestions for improving the situation, or any mistakes in the implementation?
  3. Why do I seem to have two chains in memory at all times, when I only keep the last state (which should be smaller)?
  4. Why do I run out of memory before reaching 16 GB? Is it fragmentation?

Thank you!

Did you try transfer_states_to_host (see the "Markov Chain Monte Carlo (MCMC)" page of the NumPyro documentation)?

For the model, you can use LKJCholesky instead of LKJ. It might reduce memory usage if the covariance matrix is large.

Thank you for your reply, and sorry for the delay on my end! I just tried transfer_states_to_host and it helped a lot. I have also tried LKJCholesky, and it should help in the longer term as the matrix grows. 🙂

For anyone who runs into this, this is what I added between batches:

sampler = self.model.sampler(draws=draws, tune=tunes, chains=1)
for i in range(chains):
    print(f"Starting batch {i} - starting memory {self.check_gpu_memory()}")
    mcmc = self.model.sample(sampler)
    # Transfer to CPU
    samples = jax.device_put(
        mcmc.get_samples(group_by_chain=True), jax.devices("cpu")[0]
    )
    extra_fields = jax.device_put(
        mcmc.get_extra_fields(group_by_chain=True), jax.devices("cpu")[0]
    )
    _traces.append(samples)
    _extra_fields.append(extra_fields)
    del samples, extra_fields

    # Set warmup state to the next run
    sampler._warmup_state = sampler._last_state
    # Remove the extra reference to the previous state (still held as the warmup state)
    del sampler._last_state