Dtype float16

Hi, I'm running a graph model with a large number of parameters, so I want to change the dtype of my torch tensors to float16 to save memory. But when I run elbo.compute_marginals, this error arises:

RuntimeError: "lt_cpu" not implemented for 'Half'
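
For reference, "lt" is the element-wise less-than comparison op, and older PyTorch builds don't implement it for Half tensors on CPU. A minimal sketch that reproduces the same error outside of Pyro (that compute_marginals hits such a comparison internally is my assumption based on the message):

```python
import torch

# Two half-precision tensors on the CPU.
a = torch.randn(3, dtype=torch.float16)
b = torch.randn(3, dtype=torch.float16)

# On older PyTorch this raises:
#   RuntimeError: "lt_cpu" not implemented for 'Half'
# (recent PyTorch versions do support CPU Half comparisons, so it may succeed there)
print(a < b)
```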

Is there any way to use float16 here? Thanks!

Hi, what version of Pyro and PyTorch are you using? Can you provide a full stack trace so we can see where the failure happens?
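
For example, you can print the versions like this:

```python
import torch
import pyro

# Report the installed versions so we can check which Half ops they support.
print("torch:", torch.__version__)
print("pyro:", pyro.__version__)
```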