NUTS behaviour differs between GPU and CPU

Hi All,

I have an unusual problem. I have been developing a model on my laptop, which has a GPU, and all was well. During the NUTS warmup phase the step size never dropped to very small values, and the sampler didn’t seem to hit the maximum tree depth.

Then I switched to CPU by putting jax.config.update('jax_platform_name', 'cpu') at the top of my script and noticed the adaptation phase started to look very different: the reported step sizes quickly become tiny and the maximum tree depth is exhausted.

What might be causing this? Do the CPU and GPU backends use different precision?

In case it matters, I am running this on Windows with jaxlib compiled from source. My numpyro version is 0.7.

Any advice welcome!

Probably. XLA has different backends for CPU/GPU/TPU, so the results can differ slightly. I’m not sure what’s happening in your case, though; it is a bit strange.
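One thing worth ruling out is plain float32 rounding: JAX defaults to float32 on every backend, but CPU and GPU kernels can round differently (e.g. different reduction orders), and a poorly conditioned posterior can amplify those tiny differences into divergent trajectories that drive the step-size adapter down. A minimal NumPy sketch of how little resolution float32 has (illustrative only, not specific to your model):

```python
import numpy as np

# float32 machine epsilon is ~1.19e-7, so an increment of 1e-8 is lost
# entirely; float64 resolves it without trouble.
a32 = np.float32(1.0) + np.float32(1e-8)  # rounds back to exactly 1.0
a64 = np.float64(1.0) + np.float64(1e-8)  # keeps the increment

print(a32 == np.float32(1.0))  # True  -> update vanished in float32
print(a64 == np.float64(1.0))  # False -> float64 preserves it
```

If precision is the culprit, running both backends with `jax.config.update('jax_enable_x64', True)` (JAX's flag for switching the default dtype to float64) at the top of the script should make the CPU and GPU adaptation behave much more alike.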