How do devices (CPU/GPU) work in NumPyro?

I think this might work if you use host_callback to perform the indexing on the CPU. It can be slow if the communication between host and device is expensive. See e.g. how host_callback is used in the truncated binomial thread.
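Here is a rough sketch of the idea, not a tested recipe for your model. Note that I'm using jax.pure_callback rather than host_callback itself, since as far as I know host_callback has been deprecated in recent JAX releases and pure_callback plays a similar role. The table HOST_TABLE and the helper names are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical lookup table kept in host (CPU) memory as a plain NumPy array.
HOST_TABLE = np.arange(10, dtype=np.float32) ** 2


def _host_lookup(idx):
    # Runs on the host in ordinary NumPy, outside the compiled computation.
    return HOST_TABLE[np.asarray(idx)]


@jax.jit
def lookup(idx):
    # The shape and dtype of the callback result must be declared up front.
    out_spec = jax.ShapeDtypeStruct(idx.shape, jnp.float32)
    # Each call moves idx to the host and the result back to the device.
    return jax.pure_callback(_host_lookup, out_spec, idx)


result = lookup(jnp.array([1, 3, 5]))
```

Every call to lookup makes a round trip between device and host, so inside an MCMC loop this can easily dominate the runtime.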