I have a model that uses MCMC with NUTS for training. I am not sure what the issue is, but on some days the training runs without problems and generally takes around 40 minutes. On other days it consistently stalls at around 30 iterations, with an estimated time of 20+ hours. I suspect the adaptive step size, but I do use a JAX random key. Is there a way to make this more consistent, so that I don't sometimes lose a whole day waiting for the model to train?
Maybe you can try a different init strategy; see the "Runtime Utilities" page of the NumPyro documentation.
I tried every strategy with no luck; unfortunately, most of them actually seem to be slower.
It is odd to me that for pretty much an entire day the training takes less than an hour, and then on another day it cannot get past 30 iterations. Recently it has been the latter.
Have you had any similar issues yourself?
Thank you for helping out!
This is not currently exposed in the API, but you can try increasing some of the “window” sizes; see the GitHub issue “mean_accept_prob significantly different after warmup” (pyro-ppl/numpyro, Issue #1786).
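To illustrate what the "window" sizes control, here is a standalone re-implementation of a Stan-style warmup schedule (Stan's defaults: a 75-iteration initial buffer, a 50-iteration terminal buffer, and a 25-iteration base window). It is not NumPyro's code; at the time of writing, NumPyro's equivalent logic appears to live in `numpyro.infer.hmc_util.build_adaptation_schedule`, which does not take these sizes as arguments. The function and parameter names below are hypothetical.

```python
def adaptation_schedule(num_warmup, init_buffer=75, term_buffer=50, base_window=25):
    """Return inclusive (start, end) iteration ranges for a Stan-style warmup.

    The first and last ranges adapt only the step size; the middle "slow"
    windows also re-estimate the mass matrix and double in length.
    """
    # Fall back to proportional buffers when warmup is too short for the defaults.
    if init_buffer + base_window + term_buffer > num_warmup:
        init_buffer = int(0.15 * num_warmup)
        term_buffer = int(0.1 * num_warmup)
        base_window = num_warmup - init_buffer - term_buffer
    schedule = []
    if init_buffer > 0:
        schedule.append((0, init_buffer - 1))
    end_slow = num_warmup - term_buffer
    start, size = init_buffer, max(base_window, 1)
    while start < end_slow:
        end = start + size
        # If the next (doubled) window would overrun, absorb the remainder now.
        if end + 2 * size > end_slow:
            end = end_slow
        schedule.append((start, end - 1))
        start, size = end, 2 * size
    if term_buffer > 0:
        schedule.append((end_slow, num_warmup - 1))
    return schedule

print(adaptation_schedule(1000))
# [(0, 74), (75, 99), (100, 149), (150, 249), (250, 449), (450, 949), (950, 999)]
print(adaptation_schedule(1000, base_window=100))
# [(0, 74), (75, 174), (175, 374), (375, 949), (950, 999)]
```

A larger base window yields fewer, longer slow windows, so each mass-matrix estimate is based on more draws; that is the kind of change the issue discusses.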