RuntimeError: Expected object of backend CPU but got backend CUDA for argument #3 'index'


#1

I am unsure how to move the pyro.param in the guide to the GPU.

Also, since I am enclosing the model, guide, and inference inside a class, do I need to move the class instance to the GPU as well, e.g. with .to(device)?

ternary_bayesian_neural_network is the class that contains the model, guide, and SVI inference.


#2

In the CUDA implementation linked above, the pyro.param is initialized with torch.randn, and it worked. But how do I modify my code here?


#3

Hi @srikanthram, I would recommend using torch.set_default_tensor_type(torch.cuda.FloatTensor) to avoid the CPU-to-GPU casting overhead. If you don’t want to set the default tensor type, then you should initialize each distribution with CUDA tensors. For example, use dist.Normal(torch.tensor(0., device="cuda"), torch.tensor(1., device="cuda")) instead of dist.Normal(0., 1.).

As a side note, if you use a class to define your model and guide, you don’t need pyro.param as long as the class is a subclass of torch.nn.Module. To move its parameters to CUDA, simply call your_module.cuda(). You can follow the pattern in the custom objectives tutorial to train your guide. Also replace the constant arguments of your distributions with tensors created from a parameter, e.g. dist.Normal(some_param.new_tensor(0.), some_param.new_tensor(1.)), so that they end up on the same device as that parameter.


#4

@fehiepsi Thank you so much for all the information. All my tensors are now on the GPU, and the code runs there.


#5

@fehiepsi, you mentioned that if I used a class to define my model and guide, I would not need pyro.param. I went through the custom objectives tutorial. In this line
optimizer = torch.optim.Adam(my_parameters, lr=0.001, betas=(0.90, 0.999))
what should I pass as my_parameters, and how, without using pyro.param?


#6

@srikanthram You can replace my_parameters with your_nn.parameters(). In your model and guide, you have to replace

p = pyro.param("p", torch.tensor(1.))

with

p = self.p  # here self.p is an nn.Parameter of your module

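A drawback is that you have to handle the constraints on your parameters yourself: pyro.param accepts a constraint argument (e.g. constraints.positive), but a plain nn.Parameter does not.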