Hello everybody!
I’m quite new to Pyro and love the framework, but I struggle to fully understand how to use PyroModules and Bayesian models properly.
I’m trying to do a simple classifier module embedding with a custom guide, with the output of the classifier given a Dirichlet prior and its weight / bias a Normal prior. Here, the guide implements the network used for recognition, and is used directly to predict the corresponding class.
I already succeeded by implementing the network in the model and the guide using autoguide, but for some reason I would like to do it this way. However, even though the loss is optimized, the predictions are terrible at convergence (see image below). So: first, is this structure compatible with Pyro optimizers? If yes, what is the error in the code (below)? Thank you very much!
class BayesianMLP(PyroModule):
    """Bayesian MLP classifier.

    Each linear layer's weight/bias gets a standard-Normal prior (via
    ``PyroSample``), and the class probabilities get a symmetric Dirichlet
    prior at the ``"probs"`` site.
    """

    def __init__(self, dim_in, dim_out, nlayers):
        """Build ``nlayers`` linear layers (hidden width 300) plus ReLUs.

        Args:
            dim_in: flattened input dimension.
            dim_out: number of classes.
            nlayers: number of linear layers.
        """
        super().__init__()
        dims = [dim_in] + [300] * (nlayers - 1) + [dim_out]
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.nlayers = nlayers
        for l in range(nlayers):
            setattr(self, 'linear_%d' % l, PyroModule[nn.Linear](dims[l], dims[l + 1]))
            # NOTE(review): a ReLU is attached after EVERY layer, including the
            # last one -- confirm that is intended (the commented-out guard
            # suggests it once was conditional).
            setattr(self, 'nnlin_%d' % l, nn.ReLU())
        self.init_modules()

    def init_modules(self):
        """Replace the deterministic parameters with PyroSample priors."""
        for l in range(self.nlayers):
            current_linear = getattr(self, 'linear_%d' % l)
            current_linear.weight = PyroSample(
                dist.Normal(torch.zeros_like(current_linear.weight),
                            torch.ones_like(current_linear.weight)).to_event(2))
            current_linear.bias = PyroSample(
                dist.Normal(torch.zeros_like(current_linear.bias),
                            torch.ones_like(current_linear.bias)).to_event(1))
        # Symmetric Dirichlet prior over the dim_out class probabilities.
        self.probs = PyroSample(dist.Dirichlet(1 / self.dim_out * torch.ones(self.dim_out)))

    def forward(self, x_data, y_data=None):
        """Model: sample weights/probs from the priors, observe labels.

        Args:
            x_data: input batch; flattened to (batch, -1).
            y_data: optional labels; when given, conditions the "class" site.
        Returns:
            The sampled (or observed) class assignments.
        """
        out = x_data.reshape(x_data.shape[0], -1)
        for l in range(self.nlayers):
            current_linear = getattr(self, 'linear_%d' % l)
            # Accessing .weight/.bias triggers the PyroSample sites, so the
            # guide's samples are scored against the priors here.
            out = getattr(self, 'nnlin_%d' % l)(
                nn.functional.linear(out, current_linear.weight, current_linear.bias))
        # BUG (likely cause of the bad predictions): the likelihood below
        # ignores the network output ``out`` entirely -- "class" depends only
        # on the single global Dirichlet sample ``self.probs`` broadcast over
        # the batch, so the model never conditions on x_data.
        # Fixed here: the number of classes was hard-coded to 10; use
        # self.dim_out so the model works for any output dimension.
        # with pyro.plate('data', x_data.shape[0]):
        preds = pyro.sample(
            "class",
            dist.Categorical(self.probs.expand((x_data.shape[0], self.dim_out))),
            obs=y_data)
        return preds
class BayesianMLPGuide(PyroModule):
    """Recognition guide for ``BayesianMLP``.

    Holds mean-field Normal variational posteriors (loc/std PyroParams) for
    every layer's weight and bias, runs the sampled network forward, and
    emits the softmax of the output as a Delta at the ``"probs"`` site.
    """

    def __init__(self, model):
        """Create loc/std variational parameters shaped like the model's layers."""
        super().__init__()
        self.nlayers = model.nlayers
        for l in range(model.nlayers):
            current_linear = getattr(model, 'linear_%d' % l)
            # Small random init for locations; sigmoid keeps the raw init
            # positive before scaling, and the constraint keeps std > 0.
            setattr(self, 'linear_%d_weight_loc' % l,
                    PyroParam(1e-3 * torch.randn_like(current_linear.weight)))
            setattr(self, 'linear_%d_weight_std' % l,
                    PyroParam(1e-3 * torch.sigmoid(torch.randn_like(current_linear.weight)),
                              dist.constraints.positive))
            setattr(self, 'linear_%d_bias_loc' % l,
                    PyroParam(1e-3 * torch.randn_like(current_linear.bias)))
            setattr(self, 'linear_%d_bias_std' % l,
                    PyroParam(1e-3 * torch.sigmoid(torch.randn_like(current_linear.bias)),
                              dist.constraints.positive))

    def forward(self, x_data, y_data=None):
        """Guide: sample weights/biases, compute probs, optionally predict.

        Returns:
            dict mapping each sample-site name to its sampled value.
        """
        outs = {}
        out = x_data.reshape(x_data.shape[0], -1)
        for l in range(self.nlayers):
            weight_dist = dist.Normal(getattr(self, 'linear_%d_weight_loc' % l),
                                      getattr(self, 'linear_%d_weight_std' % l))
            bias_dist = dist.Normal(getattr(self, 'linear_%d_bias_loc' % l),
                                    getattr(self, 'linear_%d_bias_std' % l))
            # Site names must match the model's PyroSample sites exactly
            # ("linear_%d.weight" / "linear_%d.bias"), with the same event
            # dims as the priors (.to_event(2) / .to_event(1)).
            weight = pyro.sample("linear_%d.weight" % l, weight_dist.to_event(2))
            bias = pyro.sample("linear_%d.bias" % l, bias_dist.to_event(1))
            outs["linear_%d.weight" % l] = weight
            outs["linear_%d.bias" % l] = bias
            out = torch.nn.functional.relu(
                torch.nn.functional.linear(out, weight, bias))
        # NOTE(review): this Delta has a per-datapoint batch dimension
        # (batch, dim_out -> event dim_out after to_event(1)), while the
        # model's "probs" site is a SINGLE global Dirichlet simplex of shape
        # (dim_out,). The shapes do not match, so the ELBO terms for "probs"
        # are mis-scored -- the model and guide need to agree on whether
        # probs is global or per-datapoint. TODO confirm intended design.
        probs = pyro.sample("probs", dist.Delta(nn.functional.softmax(out, -1)).to_event(1))
        outs['probs'] = probs
        # BUG fix: the guide must not sample the "class" site during
        # training -- it is observed in the model (obs=y_data), so a guide
        # sample there adds a log q term with no matching log p and corrupts
        # the ELBO. Sample it only for prediction (no labels), and without
        # .to_event(1) so shapes match the model's Categorical site.
        if y_data is None:
            outs['class'] = pyro.sample("class", dist.Categorical(probs=probs))
        return outs
