Model converges on CPU but not GPU

I am currently working on implementing quite a complex Bayesian hierarchical model in numpyro, which is being trained using NUTS. As expected, I see a considerable performance improvement when training on a GPU compared to a CPU. Unfortunately, when training on the GPU, some of the chains often get stuck at their initial values, meaning the model does not converge at all. When I train on the CPU, however, the model converges very well, with the expected hyperparameter values for this data set.

This occurs with both init_to_median() and init_to_sample(), and with the same random seed for the PRNG key across both the CPU and GPU runs. When I use init_to_value() and start the GPU training at roughly the expected parameter values, the model does then converge, but going forward this is not desirable, as it obviously won't generalise well to new data. Has anyone else encountered something similar, or does anyone know why training would behave so differently on GPU compared to CPU?
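For reference, a minimal sketch of the kind of setup described above; the model, data, and init values here are hypothetical stand-ins for the actual hierarchical model:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, init_to_median, init_to_value

# allow 4 parallel chains on CPU; must run before any JAX computation
numpyro.set_host_device_count(4)

def model(y):
    # hypothetical two-level stand-in for the real hierarchical model
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(5.0))
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

y = jnp.array([0.8, 1.2, 0.9, 1.1])

# same PRNG key on both CPU and GPU; only the backend differs
kernel = NUTS(model, init_strategy=init_to_median())
# the workaround that converges on GPU (values are illustrative):
# kernel = NUTS(model, init_strategy=init_to_value(values={"mu": 1.0, "sigma": 0.2}))

mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=4)
mcmc.run(random.PRNGKey(0), y)
mcmc.print_summary()
```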

are you using 64-bit precision? differences between cpu and gpu computations tend to be larger at lower precision and smaller at higher precision.
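for anyone landing here later, a minimal sketch of how to switch on 64-bit precision in numpyro (the script name in the last line is hypothetical):

```python
import numpyro

# enable 64-bit floats/ints globally in JAX (and hence numpyro);
# this must run at the top of the script, before any arrays are created
numpyro.enable_x64()

# equivalent alternatives:
#   import jax
#   jax.config.update("jax_enable_x64", True)
# or, without touching the code:
#   JAX_ENABLE_X64=True python train.py
```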

No, I wasn't. I've now run it again on the GPU with 64-bit precision and it did converge, which is reassuring; thanks for the suggestion! The downside is that the runtime was roughly twice as long this way. Since the model does converge on the GPU with 32-bit precision given careful initialisation, I was curious whether there are any differences in how the initialisation works between CPU and GPU? That seems to be the main difference in this case.
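One way to probe this directly is to inspect the initial parameter values each backend actually produces. A sketch using numpyro's initialize_model utility, with a hypothetical stand-in model and a hypothetical script name check_init.py, run once per backend:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import init_to_median
from numpyro.infer.util import initialize_model

def model(y):
    # hypothetical stand-in for the real hierarchical model
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(5.0))
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

y = jnp.array([0.8, 1.2, 0.9, 1.1])

# run this script twice and compare the printed values, e.g.
#   JAX_PLATFORM_NAME=cpu python check_init.py
#   JAX_PLATFORM_NAME=gpu python check_init.py
param_info, _, _, _ = initialize_model(
    random.PRNGKey(0), model, model_args=(y,), init_strategy=init_to_median()
)
print(jax.default_backend(), param_info.z)
```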

Thanks for your help!

i generally recommend using 64-bit precision for hmc/nuts.

numpyro code is generally agnostic to the hardware. that said, what the hardware does when it executes a floating point operation is up to it, and i generally don't have much insight into that. one thing that can happen is that gpu computations are sometimes performed in a non-deterministic way because that's faster, with the consequence that the order of aggregation isn't static. i suggest using google to learn more e.g.
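to make the aggregation-order point concrete, here's a tiny self-contained illustration (plain jax.numpy, nothing numpyro-specific): floating point addition is not associative, so a reduction whose order varies between runs or devices need not give bit-identical results.

```python
import jax.numpy as jnp

# the same three numbers summed in two different orders give
# different float32 results, because fp addition is not associative
a = jnp.float32(1e8)
b = jnp.float32(-1e8)
c = jnp.float32(1.0)

print((a + b) + c)  # 1.0 -- a and b cancel first, then c is added
print(a + (b + c))  # 0.0 -- c is absorbed into b at float32 precision

# over a long nuts trajectory, many such last-bit discrepancies in the
# potential energy and its gradients can compound, so a float32 chain
# may follow a different path on gpu than on cpu
```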