I am trying to understand how to use a GPU efficiently when my data is too big to fit in GPU memory but does fit in CPU (host) memory.
I am hoping HMCECS may be the solution. Specifically, I want to subsample the data inside a plate, then copy only the subsampled minibatch to the GPU using jax.device_put. Something like:
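```python
import jax
from numpyro import plate, sample
import numpyro.distributions as dist

def model(X, N, batch_size):  # hypothetical signature; X is kept in CPU memory
    with plate("N", size=N, subsample_size=batch_size) as ind:
        batch = X[ind]  # select the minibatch rows
        batch = jax.device_put(batch, jax.devices("gpu")[0])  # device_put takes a Device, not the string "gpu"
        sample("obs", dist.Normal(), obs=batch)
```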
Do you think this would work? In particular, will the log-likelihood be evaluated on the GPU?
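For reference, here is the same subsample-then-copy pattern in plain JAX, outside NumPyro and not using HMCECS, which at least makes it easy to check where the minibatch and its log-likelihood end up. It is a minimal sketch assuming a recent JAX in which arrays expose `.devices()`. The helper names `pick_device` and `minibatch_loglik` are hypothetical, and the code falls back to the default device when no GPU is present. It relies on JAX's rule that computations follow the placement of inputs committed with `jax.device_put`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm


def pick_device():
    """Hypothetical helper: prefer the first GPU, else the default device (e.g. CPU)."""
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:  # raised when no GPU backend is available
        return jax.devices()[0]


def minibatch_loglik(X_host, batch_size, seed=0):
    """Hypothetical helper: subsample rows of a host-resident NumPy array, copy
    only the minibatch to the chosen device, and evaluate a standard-normal
    log-likelihood there."""
    rng = np.random.default_rng(seed)
    # Indices are drawn on the host, so the full X_host never leaves CPU memory.
    idx = rng.choice(X_host.shape[0], size=batch_size, replace=False)
    batch_host = X_host[idx]  # NumPy fancy indexing: a small host-side copy
    device = pick_device()
    batch = jax.device_put(batch_host, device)  # only batch_size rows are transferred
    # batch is committed to `device`, so the computation below runs there too.
    loglik = jnp.sum(norm.logpdf(batch))
    return batch, loglik, device


X = np.random.default_rng(1).normal(size=(100_000,)).astype(np.float32)
batch, ll, dev = minibatch_loglik(X, batch_size=256)
print(dev, batch.devices(), ll.devices())
```

Whether NumPyro's HMCECS preserves this placement when the model is traced and compiled is exactly the open question above; the sketch only shows the plain-JAX behavior.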