How to put autoguide parameters on the GPU

I'm running a simple Bayesian model with a self-defined model function and an autoguide such as AutoIAFNormal, and it always fails with "Expected object of backend CPU but got backend CUDA for argument #4 'mat1'".

I put the model in a net and set net.guide = autoguide.

I simply don't know how to move the guide's parameters to the GPU.

There was a bug that has since been fixed in dev, so you can either use the dev version of Pyro or call torch.set_default_tensor_type('torch.cuda.FloatTensor') (or some other CUDA tensor type).
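
For example, here is a minimal sketch (with a made-up toy model, not the original poster's code): once the default tensor type is set to a CUDA type before the model and guide are constructed, every tensor Pyro creates, including the autoguide's parameters, lands on the GPU.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoIAFNormal
from pyro.optim import Adam

# Must run before the model/guide allocate any parameters.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Hypothetical toy model with two latent variables.
def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    scale = pyro.sample("scale", dist.LogNormal(0., 1.))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

data = torch.randn(100)       # created on the GPU via the default tensor type
guide = AutoIAFNormal(model)  # its parameters are created on the GPU too
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
for step in range(1000):
    svi.step(data)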

I'm trying to use SVI. Using the solution here to put the guide's parameters on the GPU now gives the following error:

Traceback (most recent call last):
  File "<stdin>", line 4, in <module>
  File "C:\Users\taman\anaconda3\envs\bayes_lm\lib\site-packages\tqdm\std.py", line 1133, in __iter__
    for obj in iterable:
  File "C:\Users\taman\anaconda3\envs\bayes_lm\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
    data = self._next_data()
  File "C:\Users\taman\anaconda3\envs\bayes_lm\lib\site-packages\torch\utils\data\dataloader.py", line 560, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "C:\Users\taman\anaconda3\envs\bayes_lm\lib\site-packages\torch\utils\data\dataloader.py", line 512, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "C:\Users\taman\anaconda3\envs\bayes_lm\lib\site-packages\torch\utils\data\sampler.py", line 226, in __iter__
    for idx in self.sampler:
  File "C:\Users\taman\anaconda3\envs\bayes_lm\lib\site-packages\torch\utils\data\sampler.py", line 124, in __iter__
    yield from torch.randperm(n, generator=generator).tolist()
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

Is there another way to get the guide onto the GPU? Or any suggestions on how to mitigate this error when using torch.set_default_tensor_type('torch.cuda.FloatTensor')?

Thanks for any help!

Your error appears unrelated to Pyro or autoguides: according to the stack trace, the failure happens inside a PyTorch DataLoader. You might try passing a torch.Generator object with the correct device to your DataLoader's constructor via the generator argument.
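
A minimal sketch of that workaround (the dataset here is a made-up stand-in for your own): when the default tensor type is CUDA, the sampler's torch.randperm call produces a CUDA tensor, so the shuffling generator handed to the DataLoader must live on the same device.

import torch
from torch.utils.data import DataLoader, TensorDataset

torch.set_default_tensor_type('torch.cuda.FloatTensor')

dataset = TensorDataset(torch.randn(100, 3))  # placeholder data, created on the GPU
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator(device='cuda'),  # match the default tensor device
)

for (batch,) in loader:
    pass  # each batch is already on the GPU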