Can't make a simple classification model work

Hello everybody!
I’m quite new here in Pyro and love the framework, but I kind of struggle to fully understand how to use PyroModules and Bayesian models properly.
I’m trying to do a simple classifier module embedding with a custom guide, with the output of the classifier given a Dirichlet prior and its weight / bias a Normal prior. Here, the guide implements the network used for recognition, and is used directly to predict the corresponding class.
I already succeeded by implementing the network in the model and using an autoguide for the guide, but for certain reasons I would like to do it this way instead. However, even though the loss is optimized, the predictions are terrible at convergence (see image below). So: first, is this structure compatible with Pyro optimizers? If yes, what is the error in the code (below)? Thank you very much!

class BayesianMLP(PyroModule):
    """Bayesian MLP classifier.

    Each linear layer's weight and bias carry a standard-Normal prior, and a
    global class-probability vector ``probs`` carries a symmetric Dirichlet
    prior. Hidden layers are fixed at width 300.
    """

    def __init__(self, dim_in, dim_out, nlayers):
        """Build ``nlayers`` linear layers mapping dim_in -> 300 -> ... -> dim_out."""
        super().__init__()
        dims = [dim_in] + [300] * (nlayers - 1) + [dim_out]
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.nlayers = nlayers
        for l in range(nlayers):
            setattr(self, 'linear_%d' % l, PyroModule[nn.Linear](dims[l], dims[l + 1]))
            setattr(self, 'nnlin_%d' % l, nn.ReLU())
        self.init_modules()

    def init_modules(self):
        """Replace each layer's deterministic parameters with PyroSample priors
        and attach the Dirichlet prior over class probabilities."""
        for l in range(self.nlayers):
            current_linear = getattr(self, 'linear_%d' % l)
            current_linear.weight = PyroSample(
                dist.Normal(torch.zeros_like(current_linear.weight),
                            torch.ones_like(current_linear.weight)).to_event(2))
            current_linear.bias = PyroSample(
                dist.Normal(torch.zeros_like(current_linear.bias),
                            torch.ones_like(current_linear.bias)).to_event(1))
        # Symmetric Dirichlet prior over the dim_out class probabilities
        # (sample site name: "probs", matching the guide).
        self.probs = PyroSample(dist.Dirichlet(1 / self.dim_out * torch.ones(self.dim_out)))

    def forward(self, x_data, y_data=None):
        """Model pass: sample layer weights/biases, then observe/sample labels.

        NOTE(review): the forward computation of ``out`` matters only because
        accessing ``current_linear.weight`` / ``.bias`` triggers the
        corresponding ``pyro.sample`` sites; its value never reaches the
        likelihood below, which depends solely on the global ``probs`` site.
        This disconnect between the network and the likelihood is a likely
        cause of poor predictions — confirm this is intended.
        """
        out = x_data.reshape(x_data.shape[0], -1)
        for l in range(self.nlayers):
            current_linear = getattr(self, 'linear_%d' % l)
            out = getattr(self, 'nnlin_%d' % l)(
                nn.functional.linear(out, current_linear.weight, current_linear.bias))
        # BUG FIX: expand to self.dim_out instead of the hard-coded 10, so the
        # model works for any number of output classes.
        preds = pyro.sample(
            "class",
            dist.Categorical(self.probs.expand((x_data.shape[0], self.dim_out))),
            obs=y_data)
        return preds


class BayesianMLPGuide(PyroModule):
    """Amortized guide for ``BayesianMLP``.

    Learns a mean-field Normal posterior (loc/std PyroParams) for every
    layer's weight and bias, runs the sampled network on the input, and emits
    the softmax of the final layer as a Delta posterior over ``probs``.
    """

    def __init__(self, model):
        """Create loc/std variational parameters mirroring the model's layers."""
        super().__init__()
        self.nlayers = model.nlayers
        for l in range(model.nlayers):
            current_linear = getattr(model, 'linear_%d' % l)
            # Small random init for locs; sigmoid keeps std inits in (0, 1e-3).
            setattr(self, 'linear_%d_weight_loc' % l,
                    PyroParam(1e-3 * torch.randn_like(current_linear.weight)))
            setattr(self, 'linear_%d_weight_std' % l,
                    PyroParam(1e-3 * torch.sigmoid(torch.randn_like(current_linear.weight)),
                              dist.constraints.positive))
            setattr(self, 'linear_%d_bias_loc' % l,
                    PyroParam(1e-3 * torch.randn_like(current_linear.bias)))
            setattr(self, 'linear_%d_bias_std' % l,
                    PyroParam(1e-3 * torch.sigmoid(torch.randn_like(current_linear.bias)),
                              dist.constraints.positive))

    def forward(self, x_data, y_data=None):
        """Guide pass: sample weights/biases, propagate x_data, emit probs.

        Returns a dict of every sampled value keyed by site name.
        """
        outs = {}
        out = x_data.reshape(x_data.shape[0], -1)
        for l in range(self.nlayers):
            weight_dist = dist.Normal(getattr(self, 'linear_%d_weight_loc' % l),
                                      getattr(self, 'linear_%d_weight_std' % l))
            bias_dist = dist.Normal(getattr(self, 'linear_%d_bias_loc' % l),
                                    getattr(self, 'linear_%d_bias_std' % l))
            # Site names must match the model's PyroSample sites exactly.
            weight = pyro.sample("linear_%d.weight" % l, weight_dist.to_event(2))
            bias = pyro.sample("linear_%d.bias" % l, bias_dist.to_event(1))
            outs["linear_%d.weight" % l] = weight
            outs["linear_%d.bias" % l] = bias
            out = torch.nn.functional.relu(torch.nn.functional.linear(out, weight, bias))
        # NOTE(review): this Delta site has batch shape (batch, ) while the
        # model's "probs" site (a single global Dirichlet) has batch shape ();
        # the shape mismatch will corrupt the ELBO. The model's probs likely
        # needs to be drawn per datum inside a plate for this amortized guide
        # to be well-posed — verify against the model.
        probs = pyro.sample("probs", dist.Delta(nn.functional.softmax(out, -1)).to_event(1))
        outs['probs'] = probs
        if y_data is None:
            # BUG FIX: only sample "class" at prediction time. During training
            # the model observes this site, and a guide must never contain
            # sites that are observed in the model. Also dropped the spurious
            # .to_event(1): the model's "class" site is a plain batched
            # Categorical, so the guide's shapes must match it.
            outs['class'] = pyro.sample("class", dist.Categorical(probs=probs))
        return outs