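I think you need to make sure all the parameters of the distributions are on CUDA. There's no need to move the samples, since a sample is created on the same device as the distribution's parameters. In particular, replace torch.ones()
with something like prototype_tensor.new_ones(), which creates the tensor with the same dtype and device as prototype_tensor.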
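
Here's a rough sketch of what I mean. The Normal distribution, the shapes, and the way prototype_tensor gets created are just assumptions for illustration (I don't know what your model looks like); the point is only that the parameters are built from a tensor that already lives on the right device.

```python
import torch
from torch.distributions import Normal

# Use the GPU when one is available; fall back to CPU so the sketch still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for a tensor that already lives on the target device,
# e.g. a network output or a model parameter.
prototype_tensor = torch.randn(4, 3, device=device)

# new_zeros / new_ones inherit dtype and device from prototype_tensor,
# whereas torch.ones() uses the default device (normally the CPU).
loc = prototype_tensor.new_zeros(prototype_tensor.shape)
scale = prototype_tensor.new_ones(prototype_tensor.shape)

dist = Normal(loc, scale)

# The sample comes out on the same device as the parameters,
# so no .cuda() or .to(device) call is needed on it.
sample = dist.sample()
log_prob = dist.log_prob(sample)
```

torch.ones_like(prototype_tensor) would also work when the shape should match the prototype exactly.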