@fritzo
This is solved now. Explicitly setting my learnable parameters to dtype torch.float64 (instead of the default torch.float32) made it work. So I suspect an overflow at some point created an infinite gradient that then propagated into a nan value, possibly because I'm using torch.exp as my inverse link.
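For anyone hitting the same thing, here is a minimal sketch of the suspected overflow, assuming a linear predictor around 100 (an illustrative value, not taken from my model). float32 tops out near 3.4e38, so exp overflows once its input passes roughly 88.7, while float64 can still represent exp(100) ≈ 2.7e43. The helper name exp_link_is_finite is hypothetical. Once an inf exists, operations like inf - inf or 0 * inf turn it into nan.

```python
import torch

def exp_link_is_finite(eta_value, dtype):
    """Hypothetical helper: return (forward finite, gradient finite) for mu = exp(eta)."""
    eta = torch.tensor(eta_value, dtype=dtype, requires_grad=True)
    mu = torch.exp(eta)  # exp used as the inverse link
    mu.backward()        # d mu / d eta = exp(eta), so the gradient overflows too
    return torch.isfinite(mu).item(), torch.isfinite(eta.grad).item()

print(exp_link_is_finite(100.0, torch.float32))  # (False, False)
print(exp_link_is_finite(100.0, torch.float64))  # (True, True)
```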
This was really frustrating to debug, and I don't even know what made me think of trying this. I think Pyro should have some kind of checkpoint that would have alerted me.
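One check that already exists on the PyTorch side is anomaly mode: inside torch.autograd.detect_anomaly(), a backward function that returns nan raises a RuntimeError naming the offending op. Note that it only checks backward outputs for nan, not inf. The sketch below assumes a masked exp-link mean built with torch.where (a hypothetical setup, not my actual model). The masked-out element still overflows in float32, and its zero gradient times inf gives nan in ExpBackward. The helper name backward_raises_on_nan is hypothetical.

```python
import torch

def backward_raises_on_nan(dtype):
    """Hypothetical helper: True if anomaly mode flags a nan during backward."""
    eta = torch.tensor([1.0, 100.0], dtype=dtype, requires_grad=True)
    mask = torch.tensor([True, False])  # second element is masked out
    try:
        with torch.autograd.detect_anomaly():
            # exp(100) is inf in float32 even though the mask discards it
            mu = torch.where(mask, torch.exp(eta), torch.zeros_like(eta))
            mu.sum().backward()  # ExpBackward computes 0 * inf = nan in float32
    except RuntimeError:
        return True
    return False

print(backward_raises_on_nan(torch.float32))  # True
print(backward_raises_on_nan(torch.float64))  # False
```

Anomaly mode slows training noticeably, so it is probably best switched on only while hunting a nan.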