Reducing MCMC memory usage

I am running NUTS/MCMC (on multiple CPU cores) on a fairly large dataset (400k samples) with 4 chains x 2000 steps. It actually ran until the end, but then died with an out-of-memory exception, I assume upon trying to gather all results. (There might be some unnecessary memory duplication going on in this step?)

Are there any “quick fixes” to reduce the memory footprint of MCMC? For instance, can I somehow specify that I only need certain sites stored and returned in the trace object?

Note that as opposed to many other OOM-related questions, my problem is not GPU memory; it’s main machine memory.


Hi @ewipe, could you provide a simple example? You can create a fake dataset with a simple (e.g. linear regression) model.

Hi @fehiepsi, sure! A simplified version of my actual model / use case would look like this:

```python
import numpy as np
import jax
from jax import random
import numpyro
from numpyro import sample, deterministic, plate
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

# Run the 4 chains in parallel on CPU; this must be called before JAX initializes its devices.
numpyro.set_host_device_count(4)


def model(gender_data=None, diag_data=None):

    # some population-level params
    disease_a_male = sample("disease_a_male", dist.Normal(loc=0.0, scale=2.0))
    disease_b_male = sample("disease_b_male", dist.Normal(loc=0.0, scale=2.0))
    diagnosis_logits_mdd = sample("diagnosis_logits_mdd", dist.Normal(loc=0.0, scale=2.0))
    diagnosis_logits_male = sample("diagnosis_logits_male", dist.Normal(loc=0.0, scale=1.0))
    nobs = 1 if gender_data is None else len(gender_data)

    # individual observations / samples
    with plate("N", nobs):
        # observed site A
        male = sample("male", dist.Bernoulli(np.ones((nobs,))*0.5), obs=gender_data)

        # continuous-valued latent state in [0, 1] ~ Kumaraswamy(a,b)
        disease_a = deterministic("disease_a", jax.nn.softplus(disease_a_male * male))
        disease_b = deterministic("disease_b", jax.nn.softplus(disease_b_male * male))
        #mdd_state = sample("mdd_state", dist.Kumaraswamy(disease_a, disease_b))
        # basic reparametrization of Kumaraswamy(a,b) in terms of Unif(0,1).
        mdd_base = sample("mdd_base", dist.Uniform(np.zeros((nobs,)), np.ones((nobs,))))
        mdd_state = deterministic("mdd_state", (1 - (1 - mdd_base)**(1/disease_b))**(1/disease_a))

        # observed site B
        diagnosis_logits = deterministic("diagnosis_logits", diagnosis_logits_mdd * mdd_state + diagnosis_logits_male * male)
        mdd_diagnosis = sample("mdd_diagnosis", dist.Bernoulli(logits=diagnosis_logits), obs=diag_data)

# Simulate data
data = Predictive(model, num_samples=40000)(random.PRNGKey(1))

# Infer params from simulated data
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=2000, num_chains=4, num_warmup=500)
mcmc.run(random.PRNGKey(0), gender_data=data['male'].squeeze(), diag_data=data['mdd_diagnosis'].squeeze())
```

After completion, the mcmc object will contain float32 arrays of shape (n_mcmc, n_obs) = (8k, 400k) for each of disease_a, disease_b, mdd_base, mdd_state, and diagnosis_logits, each of which takes up 12.8 GB of memory. (And again, more as a side note: my impression was that some not-in-place array copying might be going on during the final stage of the call, i.e. during sample collection?)
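For reference, the 12.8 GB figure follows directly from those shapes (taking 1 GB = 10^9 bytes). A quick check, which also shows that the five per-observation sites together come to roughly 64 GB before any temporary copies, and that float16 storage would halve both numbers:

```python
n_draws = 4 * 2000          # chains x post-warmup samples
n_obs = 400_000             # length of each per-observation site
bytes_per_float32 = 4

per_observation_sites = ["disease_a", "disease_b", "mdd_base", "mdd_state", "diagnosis_logits"]

per_site_bytes = n_draws * n_obs * bytes_per_float32
total_bytes = per_site_bytes * len(per_observation_sites)
total_bytes_float16 = total_bytes // 2   # float16 uses 2 bytes per value

print(per_site_bytes / 1e9)       # 12.8
print(total_bytes / 1e9)          # 64.0
print(total_bytes_float16 / 1e9)  # 32.0
```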

Is there any way to make the MCMC call not store all of these per-sample variables if I only need a trace for some of them? Is there a way to make them float16 instead of float32 (and would that be a good idea)? Any other tips or tricks?

> way to make the MCMC call not store all of these per-sample variables

Currently, the `collect_fields` argument only applies to nested attributes. I think you can extend its functionality at this line so that the collect operator also applies to dict items (using e.g. this Stack Overflow solution). Please feel free to make a PR.
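A minimal sketch of what such a collect operator could look like, assuming fields are given as dotted paths where each level is either an attribute (as on a namedtuple sampler state) or a dict key (as for sample sites). `get_nested`, `make_collect_fn` and the toy `State` are hypothetical names for illustration, not NumPyro internals:

```python
from collections import namedtuple
from functools import reduce


def get_nested(obj, path):
    """Hypothetical helper: follow a dotted path such as "z.disease_a_male",
    using a dict lookup at dict levels and getattr everywhere else."""
    def step(current, key):
        if isinstance(current, dict):
            return current[key]
        return getattr(current, key)
    return reduce(step, path.split("."), obj)


def make_collect_fn(collect_fields):
    """Hypothetical: build a function that extracts only the requested
    (possibly nested) fields from a sampler state."""
    def collect(state):
        return tuple(get_nested(state, field) for field in collect_fields)
    return collect


# Toy stand-in for a sampler state: a namedtuple whose `z` field is a dict of sites.
State = namedtuple("State", ["z", "potential_energy"])
state = State(z={"disease_a_male": 0.3, "mdd_base": [0.1, 0.9]}, potential_energy=1.5)

collect = make_collect_fn(("z.disease_a_male", "potential_energy"))
print(collect(state))  # (0.3, 1.5)
```

How this would plug into NumPyro's actual sample-collection code is not shown here.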

> some not-in-place array copying might be going on during the final stages

Yes, your intuition seems correct. We convert the flattened array into the desired dict format at this line. I'm not sure what a good solution for this is. You might want to first replace `vmap(unravel_fn)(collection)` by a sequential map of `unravel_fn` over `collection` to see whether it helps reduce memory requirements.
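To try a sequential alternative to `vmap(unravel_fn)(collection)`, one option is `jax.lax.map`, which applies a function to each slice along the leading axis in a loop instead of batching it. A toy sketch with an `unravel_fn` obtained from `jax.flatten_util.ravel_pytree` (the site names and sizes are made up):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Toy pytree standing in for one MCMC draw (hypothetical site names and sizes).
example = {"disease_a_male": jnp.zeros(()), "mdd_base": jnp.zeros((5,))}
_, unravel_fn = ravel_pytree(example)  # maps a flat length-6 vector back to the dict

# A flat collection of 8 draws, one row of length 1 + 5 = 6 per draw.
collection = jnp.arange(8 * 6, dtype=jnp.float32).reshape(8, 6)

# Vectorized: all rows are unraveled in one batched computation.
out_vmap = jax.vmap(unravel_fn)(collection)

# Sequential: jax.lax.map loops over the rows one at a time.
out_map = jax.lax.map(unravel_fn, collection)

print(out_map["mdd_base"].shape)  # (8, 5)
```

Both versions return the same dict of stacked arrays, so the final outputs take the same space; a sequential map can only cut down on intermediate buffers, and whether that matters here would need to be measured.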

> Is there a way to make them float16 instead of float32

It would be tricky with the current API.

> Any other tips or tricks?

You can draw samples sequentially using `post_warmup_state`; this will greatly reduce your memory requirements. You can convert chunks of samples to NumPy arrays (in float16, etc.), save or store them, and then concatenate them later.


@fehiepsi, I’ve had the same problem with not-in-place array copying after sampling causing a GPU memory error, and I’ve switched to sampling sequentially using post_warmup_state.

I do it the following way:

```python
n_batches = n_samples // 1000
mcmc_samples = [None] * n_batches
# set up MCMC; only the first run() performs warmup
self.mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=1000, num_chains=n_chains)
for i in range(n_batches):
    print(f"Batch {i+1}")
    # run MCMC for 1000 samples (rng_key: a jax.random key created elsewhere)
    self.mcmc.run(rng_key, self.spliced, self.unspliced)
    # store samples transferred to CPU
    mcmc_samples[i] = jax.device_put(self.mcmc.get_samples(), jax.devices("cpu")[0])
    # let the next run() skip warmup and continue from the last state of this batch
    self.mcmc.post_warmup_state = self.mcmc.last_state
```

However, if I draw 2000 samples in two sequential batches, the first batch goes fine, while the second still causes a memory error upon gathering the results:

```
Running MCMC in batches of 1000 samples, 2 batches in total.
First batch will include 1000 warmup samples.
Batch 1
sample: 100%|██████████| 2000/2000 [11:18<00:00,  2.95it/s, 1023 steps of size 5.13e-06. acc. prob=0.85]
Batch 2
sample: 100%|██████████| 1000/1000 [05:48<00:00,  2.87it/s, 1023 steps of size 5.13e-06. acc. prob=0.85]
2023-11-24 14:43:23.854505: W external/tsl/tsl/framework/] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.56GiB (rounded to 2750440192)requested by op 
```

What could be causing the memory usage of the second batch to be larger than that of the first? Is it possible to somehow “clean up” the MCMC object after the first 1000 samples so that the second round doesn't use more memory?

Could you open a feature request? I think we can add a method like to_numpy() or something similar.
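Until something like that exists, a user-side helper along these lines could do the conversion; `samples_to_numpy` is a hypothetical name, not a NumPyro method. It relies on `jax.device_get`, which copies a pytree of device arrays to host NumPy arrays. Copying does not by itself release device memory while other references to the original arrays (possibly inside the MCMC object) remain alive.

```python
import numpy as np
import jax
import jax.numpy as jnp


def samples_to_numpy(samples, dtype=None):
    """Hypothetical helper (not a NumPyro API): copy a dict of device arrays,
    e.g. the output of mcmc.get_samples(), to host NumPy arrays, optionally
    casting them to a smaller dtype such as np.float16."""
    host = jax.device_get(samples)  # dict of device arrays -> dict of numpy.ndarray
    if dtype is not None:
        host = {name: np.asarray(value).astype(dtype) for name, value in host.items()}
    return host


# Usage with fake "samples" standing in for mcmc.get_samples().
fake_samples = {"mu": jnp.linspace(0.0, 1.0, 5), "sigma": jnp.ones((5, 3))}
host = samples_to_numpy(fake_samples, dtype=np.float16)
print(type(host["mu"]).__name__, host["mu"].dtype, host["sigma"].shape)  # ndarray float16 (5, 3)
```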