Lifting a torch.nn.DataParallel object

Hi!

I want to run Pyro on two GPUs. The context is the following:

  1. I have a deterministic PyTorch model.
  2. Since I have two GPUs, I use model = torch.nn.DataParallel(model).
  3. I am putting priors on the weights of the model using the standard framework and lifting the DataParallel object.

When I run it, I get the error:

“Broadcast function not implemented for CPU tensors”

I would appreciate any tips or help!

Thank you!

Best regards,
Robert

I haven't used DataParallel before, but from this issue it seems as if you have non-CUDA tensors in your code.
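One way to test that hypothesis is a small helper that lists every parameter and buffer of a module that is not on a CUDA device, run on the model right before it is wrapped in DataParallel. This is a minimal sketch: the helper name find_non_cuda_tensors and the toy model are hypothetical, and it only inspects the module itself, not the input batch or any tensors the priors are built from.

```python
from typing import List

import torch.nn as nn


def find_non_cuda_tensors(module: nn.Module) -> List[str]:
    """Return the names of parameters and buffers that are not on a CUDA device."""
    offenders = []
    # Parameters (weights, biases) are what DataParallel broadcasts to each GPU.
    for name, param in module.named_parameters():
        if not param.is_cuda:
            offenders.append(name)
    # Buffers (e.g. BatchNorm running statistics) are replicated as well.
    for name, buf in module.named_buffers():
        if not buf.is_cuda:
            offenders.append(name)
    return offenders


# Hypothetical toy model; it was never moved to the GPU, so everything is listed.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 1))
print(find_non_cuda_tensors(model))
```

If the list is non-empty, moving the model to the GPU before wrapping it should clear the error. Tensors used to build the priors likely need the same treatment, since samples drawn from a distribution whose parameters are CPU tensors are themselves CPU tensors.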

Thank you very much for your help! Using .cuda() on the model fixed it! I was using .to(device) before, and I think that might have been the issue.
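For an nn.Module, model.to(torch.device("cuda")) and model.cuda() both move the parameters to the current CUDA device, so one plausible (unconfirmed) explanation is that device was resolving to the CPU. The sketch below picks the device explicitly and moves the model before wrapping it; the toy model and batch size are assumptions for illustration, and the Pyro lifting step is left out.

```python
import torch
import torch.nn as nn

# Choose the device explicitly. If CUDA is unavailable this falls back to the CPU,
# which is one way model.to(device) can quietly leave the model on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the deterministic model.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Move the model to the GPU *before* wrapping it in DataParallel.
model = model.to(device)

# Only wrap when there is more than one GPU to split the batch across.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# Inputs are created on the same device as the model.
x = torch.randn(16, 4, device=device)
out = model(x)
print(out.shape)
```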