- What tutorial are you running?

Example: Causal Effect VAE (Pyro Tutorials 1.7.0 documentation)

- What version of Pyro are you using?

1.7.2

- Please link or paste relevant code, and steps to reproduce.

Hi, I am trying to run `fit` on the CEVAE module with CUDA on Colab, and I am running into a couple of errors. First, after running

`torch.set_default_tensor_type("torch.cuda.FloatTensor")`

I get this:

```
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-4b0d2750c8e0> in <module>()
26 losses_fold = cevae_fold.fit(X_train_fold, t_train_fold, y_train_fold,
27 num_epochs=100,
---> 28 learning_rate=1e-3)
29 kfold_loss_dict[i] = losses_fold
30 # Evaluate #
```

[…]

```
/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py in __iter__(self)
122 yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
123 else:
--> 124 yield from torch.randperm(n, generator=generator).tolist()
125
126 def __len__(self) -> int:
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
```
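For reference, the first error comes from `RandomSampler`: with `shuffle=True` it calls `torch.randperm(n, generator=generator)` with a CPU generator, while `randperm` itself allocates on the default device, which becomes CUDA after `set_default_tensor_type`. A minimal sketch for a hand-written training loop, assuming plain PyTorch only (the helper name `make_shuffling_loader` and the fixed seed are hypothetical), creates the generator on whatever device the default tensor type points to:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_shuffling_loader(x, t, y, batch_size=100):
    """Hypothetical helper: a shuffling DataLoader whose RNG matches the default device.

    RandomSampler calls torch.randperm(n, generator=...), and randperm
    allocates on the default device. If the default tensor type is CUDA,
    the generator must also be a CUDA generator, otherwise PyTorch raises
    "Expected a 'cuda' device type for generator but found 'cpu'".
    """
    default_device = torch.empty(0).device  # follows set_default_tensor_type
    generator = torch.Generator(device=default_device)
    generator.manual_seed(0)  # hypothetical fixed seed, only for reproducibility
    dataset = TensorDataset(x, t, y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      generator=generator)
```

Since `CEVAE.fit` builds its own `DataLoader` internally (which is where the traceback points), this only helps if the training loop is written by hand; it does not change what `fit` does.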

If I skip the default tensor type call and instead manually move the train/treatment/outcome tensors and the CEVAE model to CUDA with `.cuda()`, I get the following error:

```
/usr/local/lib/python3.7/dist-packages/torch/distributions/normal.py in log_prob(self, value)
75 var = (self.scale ** 2)
76 log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
---> 77 return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
78
79 def cdf(self, value):
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
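The second traceback fails in `Normal.log_prob` because `value` and `loc`/`scale` live on different devices. Calling `.cuda()` on a module only moves its registered parameters and buffers; a distribution built from Python scalars, such as `Normal(0., 1.)`, gets its `loc`/`scale` on the default device (CPU here). Whether that is exactly what happens inside CEVAE's model is an assumption. A sketch in plain `torch.distributions` (the helper name `standard_normal_on` is hypothetical) that allocates the parameters on the sample's device:

```python
import torch
from torch.distributions import Normal


def standard_normal_on(latent_dim, device):
    """Hypothetical helper: standard normal whose loc/scale live on `device`.

    Normal(0., 1.) built from Python floats puts loc/scale on the default
    device (CPU unless the default tensor type was changed), so evaluating
    log_prob of a CUDA sample against it would mix devices.
    """
    loc = torch.zeros(latent_dim, device=device)
    scale = torch.ones(latent_dim, device=device)
    return Normal(loc, scale)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
z = torch.randn(4, device=device)
log_p = standard_normal_on(4, device).log_prob(z)  # all operands on one device
```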

It seems that in both cases, the components used internally by `pyro.contrib.cevae.CEVAE` (e.g. data loaders, tensors) are not moved to the device. How can I move everything to CUDA? Alternatively, can I instantiate the model differently, with my own data loaders, encoder, decoder, etc., and move everything to CUDA? The documentation for the separate `Model()` and `Guide()` modules is not very clear to me.
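To narrow down which piece is still on the CPU, a small diagnostic may help. This sketch assumes the model is a `torch.nn.Module` (which `.cuda()` working on it suggests); the helpers `to_device` and `devices_in_use` and the toy `encoder` are hypothetical stand-ins, not Pyro API:

```python
import torch
from torch import nn


def to_device(module, tensors, device):
    """Hypothetical helper: move a module and a sequence of tensors to `device`."""
    module = module.to(device)  # moves registered parameters and buffers
    return module, [t.to(device) for t in tensors]


def devices_in_use(module, *tensors):
    """Hypothetical helper: devices used by the module's parameters/buffers and `tensors`.

    Only tensors registered on the module are visible here; tensors created
    on the fly inside forward() (e.g. distribution parameters) are not.
    """
    devices = {p.device for p in module.parameters()}
    devices |= {b.device for b in module.buffers()}
    devices |= {t.device for t in tensors}
    return devices


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = nn.Sequential(nn.Linear(5, 8), nn.ELU(), nn.Linear(8, 2))  # toy stand-in
x, t, y = torch.randn(6, 5), torch.ones(6), torch.randn(6)
encoder, (x, t, y) = to_device(encoder, (x, t, y), device)
print(devices_in_use(encoder, x, t, y))
```

If this reports a single device but training still fails, the mismatch is likely in a tensor created during the forward pass, which this check cannot see.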

Thank you for your time!