I have an unusual problem. I have been developing a model on my laptop, which has a GPU, and all was well. During the NUTS warmup phase I didn't see the step size dropping to very small values, and the sampler didn't seem to hit the maximum tree depth.
Then I switched to the CPU by adding `jax.config.update('jax_platform_name', 'cpu')` at the top of my script, and noticed that the adaptation phase started to look very different. The reported step sizes quickly became tiny and the maximum tree depth was exhausted.
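A minimal diagnostic sketch, assuming a JAX version that provides `jax.default_backend()`: it confirms which backend the script actually runs on after the config switch, and what floating-point dtype JAX uses by default. Running it once with and once without the CPU override can show whether the two runs differ in backend only or also in default precision. The helper name `describe_backend` is hypothetical, for illustration only.

```python
import jax

# Select the CPU backend before any JAX arrays are created,
# mirroring the config line from the original script.
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp


def describe_backend():
    """Hypothetical helper: report the active backend and default float dtype."""
    backend = jax.default_backend()       # e.g. 'cpu' or 'gpu'
    default_float = jnp.zeros(()).dtype   # float32 unless jax_enable_x64 is enabled
    return backend, str(default_float)


backend, dtype = describe_backend()
print(f"backend={backend}, default float dtype={dtype}")
```

By default JAX uses float32 on both CPU and GPU; adding `jax.config.update('jax_enable_x64', True)` at startup switches the default to float64, which can be a useful comparison when adaptation behaves differently across runs.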
What could be the reason for this? Do the CPU and GPU backends use different floating-point precision?
In case it matters, I am running this on Windows with jaxlib compiled from source. My NumPyro version is 0.7.
Any advice welcome!