Hi @fehiepsi
Thanks a lot for your help!
I don’t think pyro.clear_param_store() is the problem. I also don’t want to increase the jitter beyond 1e-4, and that level doesn’t help anyway. I think I should try torch.float64. Another approach I remember using before is to put a prior on the lengthscale that effectively upper-bounds it, since the problem usually appears when gradient descent drives the lengthscale to a large value.
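
For reference, here is a rough sketch of how I'd combine the two ideas (double precision plus a lengthscale prior) in Pyro's GP module, following the pattern from the Pyro GP tutorial. The toy data, the LogNormal prior parameters, and the learning rate are all placeholders, not values I've tuned:

```python
import torch
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
from pyro.nn import PyroSample

# Double precision makes the Cholesky factorization far more robust
# than cranking up the jitter.
torch.set_default_dtype(torch.float64)

# Hypothetical toy data; replace with real inputs (as float64 tensors).
X = torch.linspace(0.0, 5.0, 50)
y = torch.sin(X) + 0.1 * torch.randn(X.shape)

kernel = gp.kernels.RBF(input_dim=1)
# A LogNormal prior keeps most of its mass at moderate lengthscales,
# discouraging the optimizer from drifting to the large values that
# make the kernel matrix nearly singular.
kernel.lengthscale = PyroSample(dist.LogNormal(0.0, 1.0))

gpr = gp.models.GPRegression(X, y, kernel, jitter=1e-4)

# MAP training, as in the Pyro GP tutorial: the default autoguide for
# PyroSample sites in gp modules is a Delta, so this optimizes the
# lengthscale under the prior rather than the raw marginal likelihood.
optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
for step in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()
```

For a hard upper bound instead of a soft one, swapping the LogNormal for something like dist.Uniform(0.01, 10.0) should also work, at the cost of a somewhat arbitrary cutoff.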