Can I switch the platform to CPU after fitting a NumPyro model?

My dataset is large, so the log-likelihood computation does not fit in GPU memory. I want to use the CPU instead, but I cannot use set_platform because, as I understand it, it only works at the very beginning of the program. Is there any solution?

I think it should be possible by moving the posterior samples and the “test” array to the CPU with jax.device_put(). At least, that's what I usually do when running Gaussian processes with NumPyro.
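A minimal sketch of that idea in plain JAX, not NumPyro. The samples dict and the linear predict_mean function are hypothetical stand-ins for mcmc.get_samples() and a model's predictive step. jax.device_put accepts a whole pytree, and jax.default_device (available in recent JAX versions) places newly created arrays on the CPU without needing set_platform at program start.

```python
import jax
import jax.numpy as jnp

# A CPU device is available even when JAX's default backend is a GPU.
cpu = jax.devices("cpu")[0]

# Hypothetical stand-ins for mcmc.get_samples() and the "test" array.
samples = {"w": jnp.ones((1000, 3)), "b": jnp.zeros((1000,))}
X_test = jnp.ones((50, 3))

# device_put works on a pytree, so the whole samples dict moves in one call.
samples_cpu = jax.device_put(samples, cpu)
X_test_cpu = jax.device_put(X_test, cpu)


@jax.jit
def predict_mean(samples, X):
    # Hypothetical linear predictor of shape (num_samples, num_test),
    # averaged over the posterior samples.
    return (samples["w"] @ X.T + samples["b"][:, None]).mean(axis=0)


# Inputs committed to the CPU keep the jitted computation there; default_device
# also puts any arrays created inside the block (e.g. PRNG keys) on the CPU.
with jax.default_device(cpu):
    mu = predict_mean(samples_cpu, X_test_cpu)

print(mu.shape, mu.devices())
```

In a real NumPyro workflow, the CPU-resident samples would be passed as posterior_samples to numpyro.infer.Predictive, with the test inputs also moved to the CPU.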

Thanks, I think it works!

I moved the output of mcmc.get_samples() to the CPU with jax.device_put(), but now predict tries to allocate 600 GB of memory. Any ideas?
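One possible cause, stated as an assumption: the prediction step materializes an intermediate array of roughly (num_samples × num_test × …) elements at once. A common way to bound peak memory is to evaluate predictions on slices of the test set and concatenate the results. The sketch below uses a hypothetical linear predictor in plain JAX; predict_chunk, predict_in_chunks, and chunk_size are illustrative names, not NumPyro API. Peak memory then scales with num_samples × chunk_size instead of num_samples × num_test.

```python
import jax
import jax.numpy as jnp
import numpy as np

cpu = jax.devices("cpu")[0]


@jax.jit
def predict_chunk(samples, X_chunk):
    # Hypothetical predictor: (num_samples, chunk) values, averaged over samples.
    return (samples["w"] @ X_chunk.T + samples["b"][:, None]).mean(axis=0)


def predict_in_chunks(samples, X, chunk_size=1000):
    # Evaluate one slice of the test inputs at a time on the CPU.
    out = []
    for start in range(0, X.shape[0], chunk_size):
        X_chunk = jax.device_put(X[start:start + chunk_size], cpu)
        # np.asarray copies each chunk's result to host memory.
        out.append(np.asarray(predict_chunk(samples, X_chunk)))
    return np.concatenate(out)
```

A shorter final chunk triggers one extra compilation. The same slicing idea would apply to calling Predictive on successive slices of the test inputs, or to thinning the posterior samples before prediction.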