How do devices (CPU/GPU) work in Numpyro?

Hi All,

I am trying to understand how one can use GPUs efficiently when one's data is too big to fit in GPU memory but does fit in host (CPU) memory.

I am hopeful that HMCECS may be the solution. Specifically, I am hoping to subsample the data in a plate, then copy the subsampled batch to the GPU using jax.device_put. Something like:

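import jax
from numpyro import plate, sample
import numpyro.distributions as dist

with plate("N", size=N, subsample_size=batch_size) as ind:
    batch = X[ind]  # X is held in host (CPU) memory
    batch = jax.device_put(batch, jax.devices("gpu")[0])  # device_put takes a Device, not the string "gpu"
    sample("obs", dist.Normal(), obs=batch)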

Do you think this might work? Will the log-likelihood then be computed on the GPU?
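A minimal sketch of the device-placement part of the question, assuming the indexing happens eagerly in NumPy outside any jit-traced model (inside a NumPyro model the plate indices are traced, so indexing a NumPy array X with them would not work directly). The names pick_device, X, batch_size, batch_log_lik and the standard-normal likelihood are illustrative assumptions. In JAX, a jitted function runs on the device its committed inputs live on, so once the batch has been put on a GPU, the log-likelihood is computed there too. The sketch falls back to the CPU when no GPU is visible.

```python
import jax
import numpy as np
from jax.scipy.stats import norm


def pick_device():
    # Prefer the first GPU if JAX can see one; otherwise fall back to the CPU.
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:
        return jax.devices("cpu")[0]


# Hypothetical full dataset, kept in host (CPU) memory as a NumPy array.
X = np.random.default_rng(0).normal(size=(10_000, 3)).astype(np.float32)
batch_size = 256

device = pick_device()
ind = np.random.default_rng(1).choice(X.shape[0], size=batch_size, replace=False)
# Index on the host, then copy only the small batch to the chosen device.
batch = jax.device_put(X[ind], device)


@jax.jit
def batch_log_lik(batch):
    # Standard-normal log-likelihood of the batch (illustrative model only).
    return norm.logpdf(batch).sum()


ll = batch_log_lik(batch)  # executes on the device that `batch` is committed to
print(device, ll.devices())
```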


You probably just want to do everything on the GPU.


Edit: sorry, I misread your post. I'm not sure whether that will work, as I've never tried to interleave CPU and GPU computation in JAX.

I think this might work if you use host_callback to perform the indexing on the CPU, though it can be slow if host-to-device communication is expensive. See, e.g., how host_callback is used in the truncated binomial thread.
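A minimal sketch of the idea of doing the indexing on the CPU from inside jitted code, with one swap: jax.experimental.host_callback has been deprecated in favor of jax.pure_callback, so this uses jax.pure_callback instead. X_host, BATCH_SIZE, gather_batch, batch_log_lik and the standard-normal likelihood are hypothetical. The full array stays in host memory; at run time the callback gathers the requested rows with NumPy, and only the (BATCH_SIZE, 3) batch is transferred to the device. Each call is a host round-trip, which is the communication cost mentioned above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm

# Hypothetical large dataset that stays in host memory as a NumPy array.
X_host = np.random.default_rng(0).normal(size=(100_000, 3)).astype(np.float32)
BATCH_SIZE = 128


def _gather_on_host(ind):
    # Runs as ordinary Python on the host: plain NumPy fancy indexing.
    return X_host[np.asarray(ind)]


def gather_batch(ind):
    # pure_callback lets traced code call back into Python; the result's
    # shape and dtype must be declared up front so JAX can trace through it.
    out_shape = jax.ShapeDtypeStruct((ind.shape[0], X_host.shape[1]), X_host.dtype)
    return jax.pure_callback(_gather_on_host, out_shape, ind)


@jax.jit
def batch_log_lik(key):
    # Draw subsample indices on the device, gather the rows on the host,
    # then evaluate the (illustrative) likelihood on the device.
    ind = jax.random.choice(key, X_host.shape[0], shape=(BATCH_SIZE,), replace=False)
    batch = gather_batch(ind)
    return norm.logpdf(batch).sum()


ll = batch_log_lik(jax.random.PRNGKey(0))
first_rows = jax.jit(gather_batch)(jnp.arange(5))  # should equal X_host[:5]
```

Inside a NumPyro model, one would presumably call gather_batch(ind) in place of X[ind] within the subsampling plate.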