I believe the person in the topic you linked did a reasonable job setting up their model in the end, so a narrow, Hugging Face-specific answer to your second question would be to start from their more recent code snippets in that and other topics. Since you also asked for general advice, here's a bit of context on why you're having trouble and why that approach is unlikely to succeed:
Bayesian neural networks are a subject of active research. Broadly speaking, given current inference technology there are only two problem regimes where non-experts in this research can expect them to work reliably:
- When the number of datapoints is (ideally much) larger than the number of parameters, variational inference with local reparametrization and independent variational distributions per layer or per parameter may produce reasonable predictive uncertainty estimates.
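To make the first regime concrete, here is a minimal numpy sketch of the local reparametrization trick for a single linear layer with an independent Gaussian variational distribution per weight. Rather than sampling a weight matrix and multiplying, you sample the pre-activations directly from their implied Gaussian, which reduces gradient variance. All names here are illustrative, not Pyro's API:

```python
import numpy as np

rng = np.random.default_rng(0)

batch, d_in, d_out = 4, 3, 2
x = rng.normal(size=(batch, d_in))

# Variational parameters: one mean and one (positive) std per weight.
w_mu = rng.normal(size=(d_in, d_out))
w_sigma = np.exp(rng.normal(size=(d_in, d_out)) - 2.0)

# Under a factorized Gaussian over weights, the pre-activations are also
# Gaussian: mean = x @ w_mu, variance = (x**2) @ w_sigma**2.
act_mu = x @ w_mu
act_sigma = np.sqrt((x ** 2) @ (w_sigma ** 2))

# Sample activations directly, one fresh noise draw per datapoint,
# instead of sampling a shared weight matrix for the whole batch.
eps = rng.normal(size=act_mu.shape)
activations = act_mu + act_sigma * eps
```

In a real VI setup `w_mu` and `w_sigma` would be learned by maximizing the ELBO; the point of the sketch is only the distributional identity that lets you skip sampling the weights themselves.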
- When running HMC for a long time is computationally feasible, i.e. when your model and dataset are small enough that you can run forward passes over your entire dataset and store many (tens to hundreds of) copies of your weights for prediction.
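The reason the second regime requires storing many weight copies is that prediction averages over posterior samples. A toy numpy sketch of that posterior-predictive averaging, with an illustrative single-layer "network" standing in for a real model:

```python
import numpy as np

rng = np.random.default_rng(1)

def forward(x, w):
    # Toy single-layer "network" with a softmax head (illustrative only).
    logits = x @ w
    z = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(z)
    return p / p.sum(axis=-1, keepdims=True)

x = rng.normal(size=(5, 3))                    # 5 inputs, 3 features
weight_samples = rng.normal(size=(50, 3, 4))   # 50 stored "HMC" weight draws

# Posterior predictive: average the per-sample class probabilities.
probs = np.mean([forward(x, w) for w in weight_samples], axis=0)
```

With a model the size of GPT-2, each of those stored copies is the full parameter set and each prediction costs one forward pass per sample, which is why this only works when both the model and dataset are small.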
Other approaches remain unproven at best and have generally not been evaluated or scaled up beyond small feed-forward networks on a few toy datasets.
Unfortunately, neither of these two regimes matches your particular problem (post-hoc calibration of a very large pretrained neural network). Even if it were much easier to construct such a model in Pyro, variational inference would be unlikely to produce sensible posterior or predictive uncertainty estimates, and there are no other off-the-shelf techniques, even non-Bayesian ones, that would be likely to do any better at calibrating something as large and complex as GPT-2.
For that reason the Pyro core team have tended not to invest much developer time in Bayesian neural network tooling, though we're certainly open to community contributions in this direction. See TyXe for an example of a Pyro library that addresses the problems of BNN prior and guide creation/initialization and automatic local reparametrization.