I see. I think the issue is in dist.Delta(Nx < p.cpu()): the comparison Nx < p returns a boolean tensor, not the floating-point tensor Delta is meant to hold. You probably want something like dist.Delta((Nx < p).type(torch.float)). As I mentioned earlier, this will still give you problems if you are going to run HMC.