High-level Pyro models and training loops

Does there exist, or has anyone had any success building, robust, standardized training loops that work like e.g. PyTorch Lightning? I've tried to use PyTorch Lightning to train a Pyro model with varying success. See:

It kind of works, but the optimizer and the checkpointing in particular are problematic. And I haven't tried GPU or multi-GPU yet.

Has anyone else tried this and had any success standardizing it?


Hi! I am interested in integrating PyTorch Lightning with Pyro (and GPyTorch) as well.

I wanted to know whether you've had any success with this.

I tried, but found that Pyro's optimizers didn't fit in too nicely, and gave up. However, when I recently reviewed Lightning it seemed easier to integrate custom optimizers now, so it might be worth another go.

I'm not familiar with Lightning, but if you're doing variational inference and can set up your model and guide in a single nn.Module, you can use the low-level pattern demonstrated in our custom VI objectives tutorial to compute the loss and gradients and take optimizer steps directly, in the usual PyTorch fashion.
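Roughly, the pattern looks like this (a toy Bayesian regression for illustration; the model, data, and hyperparameters here are stand-ins, not code from the tutorial):

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoNormal

# Toy Bayesian linear regression; any model/guide pair works the same way.
class Model(PyroModule):
    def __init__(self):
        super().__init__()
        self.linear = PyroModule[torch.nn.Linear](1, 1)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([1, 1]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 1.).expand([1]).to_event(1))

    def forward(self, x, y=None):
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, 0.1), obs=y)
        return mean

model = Model()
guide = AutoNormal(model)
x, y = torch.randn(100, 1), torch.randn(100)

# differentiable_loss returns a plain tensor, so an ordinary torch
# optimizer can drive the updates.
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
guide(x, y)  # call the guide once to initialize its parameters
optimizer = torch.optim.Adam(guide.parameters(), lr=0.01)

for step in range(1000):
    loss = loss_fn(model, guide, x, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()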

I would imagine that one could then adapt this to Lightning without too much trouble? We would certainly welcome a pull request with a worked example, or bug reports/feature requests if anything is blocking such an example.

I just did something similar recently, with up-to-date versions of Pyro and PyTorch Lightning.
It's just a simple variational graph autoencoder with an inner-product decoder and an embedding look-up encoder. For each node I optimize the location and variance variational parameters of its embedding using variational inference.

https://gist.github.com/rapharomero/6b0dcfe03d20e0e0fd6cf5fea67fa8f7
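In rough outline, the model and guide look like this (a simplified paraphrase with made-up names, not the exact code in the gist):

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def vgae_model(edges, num_nodes, embed_dim=16):
    # edges: a (2, E) tensor of observed node-index pairs; for simplicity
    # only positive edges are scored here (a full model would also score
    # non-edges or subsample negatives).
    with pyro.plate("nodes", num_nodes):
        z = pyro.sample("z", dist.Normal(torch.zeros(embed_dim), 1.).to_event(1))
    # inner-product decoder: edge logit = dot product of the two embeddings
    logits = (z[edges[0]] * z[edges[1]]).sum(-1)
    with pyro.plate("edges", edges.shape[1]):
        pyro.sample("obs", dist.Bernoulli(logits=logits),
                    obs=torch.ones(edges.shape[1]))

def vgae_guide(edges, num_nodes, embed_dim=16):
    # embedding look-up encoder: a free location and scale per node
    loc = pyro.param("loc", torch.zeros(num_nodes, embed_dim))
    scale = pyro.param("scale", torch.ones(num_nodes, embed_dim),
                       constraint=constraints.positive)
    with pyro.plate("nodes", num_nodes):
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))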


@rapharomero nice! I think directly calling PyroModule.parameters() should also work:

def configure_optimizers(self):
    return torch.optim.Adam(self.guide.parameters(), lr=0.1, betas=(0.90, 0.999))
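For completeness, a minimal sketch of how this could slot into a LightningModule (the class name is hypothetical, and it assumes the guide's parameters have been initialized before fitting, e.g. by calling it once on a batch):

import torch
import pyro
import pytorch_lightning as pl

class PyroLightningModule(pl.LightningModule):  # hypothetical wrapper
    def __init__(self, model, guide):
        super().__init__()
        self.model = model
        self.guide = guide
        self.loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

    def training_step(self, batch, batch_idx):
        x, y = batch
        # differentiable_loss returns a tensor, so Lightning's usual
        # backward/step machinery applies unchanged.
        loss = self.loss_fn(self.model, self.guide, x, y)
        self.log("elbo", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.guide.parameters(), lr=0.1, betas=(0.90, 0.999))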

I'm working on distributed training of scVI models and faced the same issue. If anyone is interested, our implementation is here.
