I have a `Custom_Model` class which represents Pyro’s `model()` function and hence describes all parameters involved in the prior and likelihood. In addition, it contains a Bayesian neural network instance (from the tutorials).
After successful training, I wonder how I can extract this learned Bayesian network to make predictions on new data. In other words, how can I construct an instance of `BayesianRegression` that contains the correct guide parameters?
```python
nn_model = ?
nn_model(x_new)
```
My setup looks like the following:
```python
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule, PyroSample


class ModelWrapper(PyroModule):
    def __init__(self):
        super().__init__()
        ...
        self.model = Custom_Model()
        self.guide = AutoNormal(self.model)


class Custom_Model(PyroModule):
    """Implements the Pyro model() functionality"""

    def __init__(self):
        super().__init__()
        ...
        self.net = BayesianRegression(in_features=10, out_features=5)

    def forward(self, x, y=None):
        ...  # some other sample statements
        yhat = F.log_softmax(self.net(x), dim=1)
        # the plate marks the batch dimension as conditionally independent
        with pyro.plate("data", x.shape[0]):
            return pyro.sample("obs_labels", dist.Categorical(logits=yhat), obs=y)


class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        # I moved the sample statement from the tutorial
        # into the model's forward function
        mean = self.linear(x).squeeze(-1)
        return mean
```
I guess one way is to manually extract all the posterior weights from the guide, e.g. `wrapper.guide._get_loc_and_scale('model.net.linear.weight')[0]`, and load them into an instance of the `BayesianRegression` class. However, this is very cumbersome and error-prone.
Are there better ways to achieve this?