Different HMC results between devices for the same seed

Hey all,
I’m sampling 10K particles from a Gaussian distribution. When I sample on my Windows 10 machine and on my Linux server, I get different values despite fixing the seed to the same value. Is this a bug, or is it expected?

Thanks in advance!

I think it is expected, due to differences in the floating-point precision of operations on different systems.
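To illustrate the kind of effect being described, here is a minimal sketch in plain Python (not JAX or NumPyro). It assumes the cross-machine differences come from floating-point operations being evaluated in a different order, which the thread does not confirm; the helper name `sum_in_order` is made up for the example.

```python
def sum_in_order(xs):
    """Add the values strictly left to right, one rounding step per addition."""
    total = 0.0
    for x in xs:
        total += x
    return total


values = [0.1, 0.2, 0.3]

# Same numbers, two evaluation orders. Floating-point addition is not
# associative, so the two results differ in the last bit.
forward = sum_in_order(values)
backward = sum_in_order(list(reversed(values)))

print(forward == backward)      # the two sums are not bit-identical
print(abs(forward - backward))  # but the gap is on the order of 1e-16
```

A sampler performs many such operations per step and feeds each result into the next, so a last-bit difference in one step can plausibly grow into visibly different samples.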

@fehiepsi Thanks!

I’m curious, how different are they? In my experience this sort of machine precision doesn’t make things very different at all, although MCMC might be a case where imprecision builds up quickly.

Yeah, I guess it is less different if we use float64. I’m curious too.

As you can see, the difference is indeed minute. Nonetheless, for my needs this difference causes a headache, as these particles are fed into an MCTS and we get different results due to this variability…
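For anyone who wants to put a number on "minute", here is a small sketch for comparing two particle sets saved from different machines. The function names and the two short lists are hypothetical placeholders, not the actual samples from this thread, and the tolerances are arbitrary starting points.

```python
import math


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def all_close(a, b, rel_tol=1e-6, abs_tol=1e-8):
    """True if every pair of elements agrees within the given tolerances."""
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol) for x, y in zip(a, b)
    )


# Hypothetical placeholder values standing in for the two machines' particles.
particles_windows = [0.12345678, -1.5, 2.25]
particles_linux = [0.12345679, -1.5, 2.25]

print(particles_windows == particles_linux)            # exact equality fails
print(max_abs_diff(particles_windows, particles_linux))
print(all_close(particles_windows, particles_linux))   # equal within tolerance
```

A tolerance check like this shows how far apart the runs are, but it does not help when a downstream consumer such as a tree search branches on exact values.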

I have also observed differences when running a time-series forecasting model with a fixed seed across different machines with different CPUs (different generations of Intel processors). With a fixed seed I can replicate the same posteriors exactly on different computers with the same CPU, but not across computers with different CPUs. I still saw variance across runs after changing the array type from the default float32 to float64. I suspect this results from differences in how JAX/XLA compiles for different CPU architectures, and I am not sure how to control for this source of variation. Even with models that converge reasonably well this is still somewhat of an issue, because some use cases require exactly controlling for different sources of variation in numerical experiments, and in my setting I cannot control the CPU architecture of the machines that run the models. I’ve read the “How Does XLA Work” section of the XLA Architecture page in the TensorFlow documentation. If XLA compilation is the source of the variation across the different CPUs, I am wondering whether there are XLA flags that can be used to prevent certain optimizations.
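On the mechanics of passing flags: XLA reads options from the `XLA_FLAGS` environment variable, which has to be set before `jax` is first imported. The sketch below only shows how to set it from Python. The helper `add_xla_flag` is a made-up name, and `--xla_cpu_enable_fast_math=false` is used purely as an example of a CPU-related flag; nothing in this thread establishes that it, or any other flag, removes the differences between CPU generations.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Append `flag` to XLA_FLAGS unless it is already present; return the new value."""
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# Example flag only (an assumption, untested for this problem).
add_xla_flag("--xla_cpu_enable_fast_math=false")

# Import jax / numpyro only after this point so the flags are picked up, e.g.:
# import jax
print(os.environ["XLA_FLAGS"])
```

The same effect can be had by exporting `XLA_FLAGS` in the shell before launching the script, which avoids any dependence on import order.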

@jawolf314 Could you ask the question on the JAX forum? If you find a solution, could you also post it here to help others control this behavior? We can add a section to the README to mention this. Thanks!

@fehiepsi Thank you for the recommendation! I’ve posted a related question on the JAX forum here: https://github.com/google/jax/discussions/9686