Using Stochastic Weight Averaging (SWA) in Pyro?

Dear Pyro Team,

I hope this message finds you well. I recently learned that Stochastic Weight Averaging (SWA) can improve the generalization of models that are already well tuned, across a range of practical applications. I noticed that PyTorch 1.6 introduced the torch.optim.swa_utils module, which provides a convenient way to implement SWA.

I am currently working with Pyro and would like to know whether it offers equivalent or similar functionality, for example custom functions or modules designed for implementing SWA. I would greatly appreciate any guidance or advice you can provide.

Thank you for your time and assistance.

Best regards,

There is no such functionality. The closest would be TyXe, which offers some Pyro-based machinery for Bayesian neural networks.

Generally speaking, we choose to keep most Bayesian deep learning techniques out of scope, essentially because they tend to be more ad hoc than not, so it's difficult to integrate them coherently into a probabilistic programming framework. By contrast, it's clear how e.g. discrete enumeration, stochastic variational inference, and MAP estimation can be coherently combined. We also have to choose our battles, and packages like torch.optim.swa_utils are likely better suited for these sorts of things.
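
For reference, here is a minimal sketch of what using torch.optim.swa_utils looks like in plain PyTorch. It is not Pyro functionality: the toy data, the linear model, the learning rates, and the swa_start epoch are all made up for illustration. Using it with a Pyro model would presumably mean driving the optimization with a plain torch optimizer yourself, which is an assumption and not something shown here.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR

torch.manual_seed(0)

# Hypothetical toy regression data: y = 3x + noise.
x = torch.linspace(-1, 1, 64).unsqueeze(-1)
y = 3.0 * x + 0.1 * torch.randn_like(x)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# AveragedModel keeps a running (equal-weight) average of the model's parameters.
swa_model = AveragedModel(model)
# SWALR anneals the learning rate towards swa_lr once we start stepping it.
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

num_epochs, swa_start = 100, 50  # illustrative values, not recommendations
for epoch in range(num_epochs):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    if epoch >= swa_start:
        # Fold the current weights into the running average.
        swa_model.update_parameters(model)
        swa_scheduler.step()

# Evaluate the averaged weights.
with torch.no_grad():
    swa_loss = nn.functional.mse_loss(swa_model(x), y).item()
print(f"loss with averaged weights: {swa_loss:.4f}")
```

If the model had batch-norm layers, you would also recompute their statistics for the averaged weights with torch.optim.swa_utils.update_bn before evaluating; the linear model above does not need that.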


I drafted an implementation here: "SWA implementation" by yuanqing-wang, Pull Request #3249 in pyro-ppl/pyro on GitHub.