My dataset is large, so the log-likelihood computation does not fit in GPU memory. I want to run it on the CPU instead, but I cannot use set_platform because it reportedly only takes effect at the beginning of the program. Is there any solution?
I think it should be possible by moving the posterior samples and the “test” array to the CPU with jax.device_put(). At least, that's what I usually do when running Gaussian processes with NumPyro.
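Here is a minimal sketch of that idea using plain JAX, assuming a hypothetical Gaussian model whose posterior has sites named "mu" and "sigma"; the random samples below only stand in for the dict returned by mcmc.get_samples(). jax.device_put() accepts a whole pytree, so the samples dict can be moved in one call. Once the inputs are committed to the CPU device, JAX runs the computation there even if a GPU is the default backend. In a real NumPyro workflow you would pass the moved samples and test data to whatever computes the log-likelihood (for example numpyro.infer.log_likelihood) instead of the hand-written function here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# First CPU device; available even when a GPU is the default backend.
cpu = jax.devices("cpu")[0]


def to_cpu(tree):
    """Copy every array in a pytree (e.g. the samples dict) to the CPU."""
    return jax.device_put(tree, cpu)


def pointwise_log_lik(samples, y):
    # Hypothetical Gaussian model with sites "mu" and "sigma".
    # Result has shape (num_samples, num_obs).
    return norm.logpdf(
        y[None, :],
        loc=samples["mu"][:, None],
        scale=samples["sigma"][:, None],
    )


# Stand-in for mcmc.get_samples(): 1000 draws of two scalar parameters.
key_mu, key_sigma = jax.random.split(jax.random.PRNGKey(0))
samples = {
    "mu": jax.random.normal(key_mu, (1000,)),
    "sigma": jnp.exp(0.1 * jax.random.normal(key_sigma, (1000,))),
}
y_test = jnp.linspace(-3.0, 3.0, 500)

samples_cpu = to_cpu(samples)
y_cpu = to_cpu(y_test)

# Inputs are committed to the CPU, so the jitted computation runs there too.
log_lik = jax.jit(pointwise_log_lik)(samples_cpu, y_cpu)
print(log_lik.shape, log_lik.devices())
```

Note that the pointwise result has one entry per (sample, observation) pair, so its size grows as num_samples × num_obs.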
Thanks, I think it works!
I moved the output of mcmc.get_samples() to the CPU with jax.device_put(), but now predict tries to allocate 600 GB of memory. Any ideas?