Hi All,
I am trying to understand how to use GPUs efficiently when the data is too large to fit in GPU memory but does fit in host (CPU) memory.
I am hoping HMCECS may be the solution. Specifically, I would like to subsample the data in a `plate`, then copy only the subsampled minibatch to the GPU using `jax.device_put`. Something like:
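```python
import jax
import numpyro.distributions as dist
from numpyro import plate, sample

def model(X, N, batch_size):  # X: full dataset held in host memory
    with plate("N", size=N, subsample_size=batch_size) as ind:
        batch = jax.device_put(X[ind], jax.devices("gpu")[0])  # copy only the minibatch to the GPU
        sample("obs", dist.Normal(), obs=batch)
```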
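For comparison, here is a minimal sketch of the same data flow written in plain JAX, outside of NumPyro. It is not HMCECS and not NumPyro's API; it only illustrates the pattern under these assumptions: the full dataset stays in host memory as a NumPy array, the minibatch indices are drawn and sliced on the host before any JAX tracing, only the minibatch is copied with `jax.device_put`, and the jitted log-likelihood is scaled by `N / batch_size`, mirroring how a subsampled plate scales the likelihood. The sizes and the helper names `subsample_to_device` and `scaled_normal_loglik` are hypothetical. It uses `jax.devices()[0]`, which is a GPU if JAX sees one and the CPU otherwise, so it also runs on a CPU-only machine.

```python
import jax
import numpy as np
from jax.scipy.stats import norm

# Full dataset kept in host (CPU) memory as a NumPy array; hypothetical sizes.
N, batch_size = 100_000, 1_000
X = np.random.default_rng(0).normal(size=N).astype(np.float32)

# First available device: a GPU if JAX sees one, otherwise the CPU.
device = jax.devices()[0]

def subsample_to_device(rng, X, batch_size, device):
    # Hypothetical helper: draw indices and slice on the host,
    # so only the minibatch (not the full X) is copied to `device`.
    ind = rng.choice(X.shape[0], size=batch_size, replace=False)
    return jax.device_put(X[ind], device)

@jax.jit
def scaled_normal_loglik(batch, loc, scale):
    # Minibatch Normal log-likelihood scaled by N / batch_size,
    # an unbiased estimate of the full-data log-likelihood.
    return norm.logpdf(batch, loc, scale).sum() * (N / batch_size)

rng = np.random.default_rng(1)
batch = subsample_to_device(rng, X, batch_size, device)
ll = scaled_normal_loglik(batch, 0.0, 1.0)
print(batch.devices(), ll.devices())  # both sets should contain only `device`
```

Because `batch` is committed to `device`, the jitted computation runs there and its output lives there too; checking `.devices()` on the result is a quick way to confirm where the likelihood was evaluated.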
Do you think this would work? In particular, would the log-likelihood then be evaluated on the GPU?
Thanks!