How to move an MCMC run from GPU to CPU

Hi NumPyro,

I’ve run MCMC inference on a model on the GPU. The sampling finished successfully and I was able to get posterior samples.

However, when I try to visualize my samples and traces with ArviZ (az.summary, or az.from_numpyro followed by az.plot_trace), I get the error RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 11044800000 bytes. To avoid this error, I have been trying to move the numpyro.infer.mcmc.MCMC object resulting from the sampling to the CPU, but I haven’t been able to figure out how to do this.

Is there any way I could move the numpyro.infer.mcmc.MCMC object resulting from the sampling from GPU to CPU?


I think I figured it out.

import jax

cpus = jax.devices("cpu")
# Copy the sampler state and the stored run() keyword arguments to the first CPU device
mcmc._states = jax.device_put(mcmc._states, cpus[0])
mcmc._kwargs = jax.device_put(mcmc._kwargs, cpus[0])

where mcmc is the numpyro.infer.mcmc.MCMC object. This seems to work.
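The workaround relies on jax.device_put accepting a whole pytree (nested dicts, tuples, NamedTuples of arrays) together with a target device. Below is a minimal sketch of the same idea applied to a stand-in dict rather than a real MCMC object. The helper move_to_cpu and the states dict are hypothetical, and note that mcmc._states and mcmc._kwargs are private NumPyro attributes, so writing to them may break in a future version.

```python
import jax
import jax.numpy as jnp


def move_to_cpu(tree):
    """Hypothetical helper: copy every array leaf of a pytree to the first CPU device."""
    cpu = jax.devices("cpu")[0]
    # device_put maps over all leaves of the pytree and keeps its structure
    return jax.device_put(tree, cpu)


# Hypothetical stand-in for MCMC state: a nested dict of arrays (4 chains, 1000 draws)
states = {"z": {"loc": jnp.zeros((4, 1000)), "scale": jnp.ones((4, 1000))}}
states_cpu = move_to_cpu(states)

# Report the shape and the platform each leaf now lives on
for leaf in jax.tree_util.tree_leaves(states_cpu):
    print(leaf.shape, {d.platform for d in leaf.devices()})
```

Because device_put preserves the pytree structure, the moved object can be assigned back in place of the original without any other changes.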


It would be nice to identify and fix this issue upstream in ArviZ. I guess some jax.device_get logic is missing in ArviZ.

What do you mean by missing jax.device_get logic in ArviZ?

Should this be handled in ArviZ? Or should the numpyro.infer.mcmc.MCMC object have a device_put method?

I think you can use device_get to move your device arrays (e.g. on the GPU) to the CPU. The logic should live in ArviZ. If you look at the source code referenced in my last comment, you will see that in many places device arrays are moved to the CPU through device_get. Some of these calls might be missing, e.g. for kwargs, which could have caused your issue.
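jax.device_get takes a pytree of device arrays and returns the same structure with NumPy arrays in host memory. Here is a minimal sketch of that step, assuming the samples would come from mcmc.get_samples(group_by_chain=True). The samples dict below is a hypothetical stand-in with made-up names and shapes. The resulting dict of NumPy arrays could then be passed to ArviZ, e.g. az.from_dict(posterior=host_samples), which only populates the posterior group.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for mcmc.get_samples(group_by_chain=True):
# device arrays shaped (num_chains, num_samples, ...)
samples = {
    "mu": jnp.linspace(0.0, 1.0, 2 * 500).reshape(2, 500),
    "sigma": jnp.ones((2, 500, 3)),
}

# Copy every leaf back to host memory; each one becomes a numpy.ndarray
host_samples = jax.device_get(samples)

for name, value in host_samples.items():
    print(name, type(value).__name__, value.shape)
```

Doing this conversion once, before any plotting or summaries, keeps the later steps working on host arrays instead of device buffers.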