Training Pyro models that include nn.Modules on the GPU

Hi there

I’m trying to train a VAE-style model with Pyro on an A100 GPU. All of the torch tensors are on the GPU device, as are the nn.Modules that are components of the model and guide.

Training runs at roughly the same speed on the A100 as it does on the CPU of my local machine (M1 Max, 64 GB RAM).

Relative to NumPyro, I can find less documentation for Pyro on GPU training… can someone point me to relevant docs, or to any fatal mistakes I might be making?


Hi @hasco641

The VAE tutorial (Variational Autoencoders — Pyro Tutorials 1.8.6 documentation) has an example of training a VAE model on the GPU.

You can monitor GPU memory and compute utilization during training by running nvidia-smi in another terminal. If both are low, that suggests the program is not using the GPU efficiently — for example, the dataloader can be rate-limiting, which can be alleviated by using multiple DataLoader workers.
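A minimal sketch of that DataLoader setup, using a toy in-memory dataset as a stand-in for your real one (the dataset, batch size, and worker count here are illustrative assumptions):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset standing in for your real data.
dataset = TensorDataset(torch.randn(256, 5))

# num_workers > 0 moves batch preparation into worker processes so the
# GPU isn't starved waiting for data; pin_memory=True uses page-locked
# host memory, which makes host-to-GPU copies faster.
loader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for (batch,) in loader:
    # non_blocking=True lets the copy overlap with GPU compute
    # (it only has an effect when the source tensor is pinned).
    batch = batch.to(device, non_blocking=True)
```

Tuning num_workers (often 2–8, depending on CPU cores and how expensive your transforms are) until nvidia-smi shows sustained GPU utilization is a reasonable way to find the sweet spot.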