Issues running CEVAE with CUDA on Colab

Hi, I am trying to run fit() on pyro.contrib.cevae.CEVAE with CUDA on Colab and I am running into two errors. First, if I run

torch.set_default_tensor_type("torch.cuda.FloatTensor")

I get this:

RuntimeError Traceback (most recent call last)

<ipython-input-6-4b0d2750c8e0> in <module>()
     26     losses_fold = cevae_fold.fit(X_train_fold, t_train_fold, y_train_fold,
     27                     num_epochs=100,
---> 28                     learning_rate=1e-3)
     29     kfold_loss_dict[i] = losses_fold
     30     # Evaluate #

[…]

/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py in __iter__(self)
    122             yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
    123         else:
--> 124             yield from torch.randperm(n, generator=generator).tolist()
    125 
    126     def __len__(self) -> int:

RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
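
As far as I can tell, the sampler builds its own generator, and a generator created without arguments lives on the CPU, while the default tensor type makes torch.randperm allocate on CUDA. A minimal sketch of the pieces involved, in plain PyTorch (this is my own illustration, not the sampler or CEVAE code, and the explicit device="cpu" is only there so it runs anywhere):

```python
import torch

# A generator created without arguments is a CPU generator.
gen = torch.Generator()
gen.manual_seed(0)
print(gen.device)

# randperm requires the generator's device to match the output device.
# Here both are CPU, so the call succeeds; if the output were allocated
# on CUDA with this CPU generator, the call would be rejected.
perm = torch.randperm(5, generator=gen, device="cpu")
print(perm.tolist())
```
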

If I skip the default tensor type command and instead manually move the train/treatment/outcome tensors and the CEVAE model to the GPU with .cuda(), I get the following error:

/usr/local/lib/python3.7/dist-packages/torch/distributions/normal.py in log_prob(self, value)
 75         var = (self.scale ** 2)
 76         log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
---> 77         return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
 78 
 79     def cdf(self, value):

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
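
To illustrate the kind of mismatch I suspect, here is a minimal sketch in plain PyTorch. It is not the CEVAE code; the variable names are mine, and I am only assuming that somewhere a distribution is built from tensors that were never moved to the GPU:

```python
import torch
from torch.distributions import Normal

# Parameters created without an explicit device land on the default device (CPU).
prior = Normal(torch.zeros(3), torch.ones(3))

# Stand-in for a value produced by a model that was moved with .cuda().
value = torch.zeros(3)
if torch.cuda.is_available():
    value = value.cuda()

loc_device = prior.loc.device.type
value_device = value.device.type
print(f"prior.loc on {loc_device}, value on {value_device}")

if loc_device != value_device:
    # On a GPU runtime this is where the "same device" error shows up.
    try:
        prior.log_prob(value)
    except RuntimeError as err:
        print("mismatch reproduced:", err)
else:
    # Same device: log_prob works and returns one entry per element.
    print(prior.log_prob(value).shape)
```

On a CPU-only runtime both tensors report cpu and log_prob succeeds; on a GPU runtime the value is on cuda while prior.loc stays on cpu.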

It seems that in both cases the modules called by pyro.contrib.cevae.CEVAE do not move everything they need (e.g. dataloaders, tensors) to the device. How can I send everything to CUDA? Can I instantiate the model differently, with my own dataloaders, encoder, decoder, etc., and send everything to CUDA myself? The documentation for the separate Model() and Guide() modules is not very clear to me.

Thank you for your time!