Saving GPU RAM by periodically offloading old samples to CPU RAM during sampling

I'm coming here from PyMC, where I use the function pymc.sampling_jax.sample_numpyro_nuts(), which, as far as I know, is a wrapper for NumPyro's MCMC(NUTS, …). For reference, the call to NumPyro's NUTS sampler in PyMC is here: pymc/jax.py at main · pymc-devs/pymc · GitHub

The problem in a nutshell, as I understand it, is that get_samples() does not offload the accumulated samples from the GPU until sampling has finished, which makes GPU RAM the limiting factor for how many samples can be drawn in one run.

How difficult would it be to implement periodic offloading of old samples from the GPU, so that GPU RAM is no longer the limiting factor for the number of samples that can be collected in one run? For example, once 500 draws have been collected, offload them to CPU RAM and then store the next draws in the same GPU buffer, which is no longer in use.
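The idea can be sketched in plain JAX, as a minimal illustration only. Here a Gaussian random walk run with jax.lax.scan stands in for the sampler (it is not NUTS), and random_walk_chunk / sample_with_offloading are hypothetical names. Each chunk of 500 draws is copied to host memory with jax.device_get, and the chain state is carried into the next chunk so the draws still form one continuous chain. On a CPU-only machine the copy is effectively a no-op, but the structure is the same.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnums=2)
def random_walk_chunk(key, state, num_draws):
    """Hypothetical stand-in for a sampler: num_draws steps of a Gaussian random walk."""

    def step(x, k):
        x_new = x + 0.1 * jax.random.normal(k)
        return x_new, x_new  # (carry for the next step, draw recorded for this step)

    keys = jax.random.split(key, num_draws)
    return jax.lax.scan(step, state, keys)  # -> (last state, draws of shape (num_draws,))


def sample_with_offloading(total_draws=2000, chunk_size=500, seed=0):
    """Hypothetical helper: sample in chunks and keep only host (NumPy) copies."""
    key = jax.random.PRNGKey(seed)
    state = jnp.zeros(())
    host_chunks = []  # NumPy arrays living in CPU RAM
    for start in range(0, total_draws, chunk_size):
        key, subkey = jax.random.split(key)
        n = min(chunk_size, total_draws - start)
        state, draws = random_walk_chunk(subkey, state, n)  # state carries over chunks
        host_chunks.append(jax.device_get(draws))  # copy this chunk to host memory
        del draws  # drop the device reference so its buffer can be freed and reused
    return np.concatenate(host_chunks)


all_draws = sample_with_offloading()
print(all_draws.shape)  # (2000,)
```

Only the small chain state stays on the device between chunks, so device memory is bounded by the chunk size rather than by the total number of draws.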

Is this something that could be done by changing the NumPyro codebase, or should I redirect this feature request to the JAX community?

EDIT: I found GPU Memory · Issue #539 · pyro-ppl/numpyro · GitHub, which suggests that the feature I want can be obtained with a loop like:

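import gc
import numpy as onp
from jax import random
from numpyro.infer import MCMC, NUTS
# test1 (the model) and GPU_mem_state() (a memory-report helper) are defined in the issue
mcmc = MCMC(NUTS(test1), num_warmup=100, num_samples=100)
trace = []
for i in range(10):
    print("\n" + GPU_mem_state())
    mcmc.run(random.PRNGKey(i))
    samples = mcmc.get_samples()
    # get_samples() returns a dict keyed by site name: copy each site's draws to host (NumPy)
    trace.append({k: onp.atleast_1d(onp.asarray(v)) for k, v in samples.items()})
    del samples
    # private attributes of the NumPyro version used in the issue (newer: post_warmup_state / last_state)
    mcmc._warmup_state = mcmc._last_state
    gc.collect()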

This looks promising. Is it really that easy to restart sampling from the last state? That would be awesome!

Yes, we can restart sampling from the last state. A more official solution is to use post_warmup_state (see the example in the MCMC documentation).


Thank you!

I almost have it working; I just need to know which function to use to merge the outputs of get_samples(). Would NumPy's concatenate() work for this?

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    ## proof of concept, get more samples
    number_of_iterations = 5
    for i in range(number_of_iterations):
        print("Batch", i, "of samples collected", file=sys.stdout)
        pmap_numpyro.post_warmup_state = pmap_numpyro.last_state
        pmap_numpyro.run(pmap_numpyro.post_warmup_state.rng_key,
                         extra_fields=(
                             "num_steps",
                             "potential_energy",
                             "energy",
                             "adapt_state.step_size",
                             "accept_prob",
                             "diverging",
                         ),
        )
        more_raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
        raw_mcmc_samples = numpy.concatenate(raw_mcmc_samples, more_raw_mcmc_samples)

I guess you can do

jax.tree_util.tree_map(np.concatenate, raw_mcmc_samples, more_raw_mcmc_samples)

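Almost. I had to pass the matching leaves to concatenate() as a list and concatenate along the draw axis (axis=1, because the samples are grouped by chain):

jax.tree_util.tree_map(lambda a, b: np.concatenate([a, b], axis=1), raw_mcmc_samples, more_raw_mcmc_samples)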

Thanks again!
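For reference, a minimal NumPy sketch of the merge. merge_batches is a hypothetical helper, and the dicts only mimic the output of get_samples(group_by_chain=True), where every leaf has shape (chain, draw, ...). Collecting the batches in a list and concatenating once at the end copies each draw once, instead of re-copying the accumulated result on every loop iteration.

```python
import jax
import numpy as np


def merge_batches(batches):
    """Hypothetical helper: concatenate same-structured sample dicts along axis 1 (draws)."""
    # tree_map walks all dicts in lockstep and passes matching leaves to the lambda
    return jax.tree_util.tree_map(lambda *leaves: np.concatenate(leaves, axis=1), *batches)


# Hypothetical batch: 4 chains x 500 draws, one scalar site and one vector site
batch = {"mu": np.zeros((4, 500)), "theta": np.ones((4, 500, 3))}
merged = merge_batches([batch, batch, batch])
print(merged["mu"].shape)     # (4, 1500)
print(merged["theta"].shape)  # (4, 1500, 3)

# Pitfall: mapping np.concatenate over a tuple of dicts sees each leaf on its own, so it
# flattens the chain axis into the draw axis instead of merging the two batches
flat = jax.tree_util.tree_map(np.concatenate, (batch, batch))
print(flat[0]["mu"].shape)    # (2000,)
```

The pitfall line runs without error, which is what makes it easy to miss: the result is still a tuple of two separate dicts, each with its chains flattened.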

concatenate() returns a new array every time, so concatenating inside the loop re-copies everything collected so far and wouldn't scale nicely. I also tried append(), but it is no more efficient (np.append() calls concatenate() internally and also returns a new array), and since the structure here is ragged I never got it to work anyway. A slightly simplified version worked with

np.append(arr1, arr2, axis=2).tolist()

but the real model I used had a more complex shape (some elements were themselves lists, I assume). Instead, I chose to process each batch to completion (creating an ArviZ object for each batch), save it to disk, and concatenate the batches outside of the sampling function using arviz.concat().
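A minimal sketch of that last approach, assuming ArviZ is installed; the per-batch arrays are random placeholders for real draws, and the save-to-disk step is only indicated in a comment. az.from_dict expects arrays shaped (chain, draw, ...), and az.concat with dim="draw" joins the batches along the draw dimension.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)

# Placeholder batches: 4 chains x 500 draws of a scalar "mu" and a length-3 "theta"
batches = [
    az.from_dict(posterior={"mu": rng.normal(size=(4, 500)),
                            "theta": rng.normal(size=(4, 500, 3))})
    for _ in range(3)
]
# In practice each batch would be written with idata.to_netcdf(...) and read back
# with az.from_netcdf(...) before this step.
idata = az.concat(*batches, dim="draw")
print(idata.posterior["mu"].shape)     # (4, 1500)
print(idata.posterior["theta"].shape)  # (4, 1500, 3)
```

Because every batch is a complete InferenceData object, ragged or nested site shapes are handled by ArviZ/xarray rather than by hand-written array code.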