Bayesian Classification

Hi there!
I haven’t found a simple tutorial on how to use Pyro to classify a binary toy dataset such as moons or blobs.
So my question is: is my definition of the model below suitable for classification, and can it be trained with SVI or MCMC without any further changes compared to regression tasks? (E.g., do I need a loss other than Trace_ELBO for SVI?)

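import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianClassification(PyroModule):
    def __init__(self, in_features, hidden_neuron, out_features=1):
        super().__init__()  # PyroModule must be initialised before submodules are assigned
        self.layer_1 = PyroModule[nn.Linear](in_features, hidden_neuron)
        self.layer_1.weight = PyroSample(dist.Normal(0., 1.).expand([hidden_neuron, in_features]).to_event(2))
        self.layer_1.bias = PyroSample(dist.Normal(0., 10.).expand([hidden_neuron]).to_event(1))

        self.layer_2 = PyroModule[nn.Linear](hidden_neuron, hidden_neuron)
        self.layer_2.weight = PyroSample(dist.Normal(0., 1.).expand([hidden_neuron, hidden_neuron]).to_event(2))
        self.layer_2.bias = PyroSample(dist.Normal(0., 10.).expand([hidden_neuron]).to_event(1))

        # Binary labels need a single output unit (one Bernoulli logit per point),
        # and the prior shapes must match nn.Linear's [out, in] weight layout.
        self.layer_3 = PyroModule[nn.Linear](hidden_neuron, out_features)
        self.layer_3.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, hidden_neuron]).to_event(2))
        self.layer_3.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        h = torch.relu(self.layer_1(x))
        h = torch.relu(self.layer_2(h))
        logits = self.layer_3(h).squeeze(-1)  # [N, 1] -> [N]
        with pyro.plate("data", x.shape[0]):
            # logits= is numerically safer than probs=sigmoid(...); y must be a float tensor
            pyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)
        return torch.sigmoid(logits)  # P(y = 1 | x, weights)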

Thanks for helping!


The model looks reasonable to me. If you are using an AutoNormal guide, then it is better to use the TraceMeanField_ELBO objective.

I tried it with the TraceMeanField_ELBO loss, but it does not seem to learn the correct class boundary, e.g. on the moons dataset.
I’m using
self.guide = AutoDiagonalNormal(self.model)
self.adam = pyro.optim.Adam({"lr": 0.01})
self.svi = SVI(self.model, self.guide, self.adam, loss=TraceMeanField_ELBO())
for j in range(self.num_iterations):
    loss = self.svi.step(X_train, y_train)

and run SVI for 1000 iterations (maybe that is too few? The loss does seem to converge, though).
Can you spot any error in the code above?

The code looks correct to me, aside from the guide: for mean-field inference, use AutoNormal rather than AutoDiagonalNormal. You might also want to remove all the self. stuff (i.e. make model and guide plain functions rather than methods, and make svi a global variable rather than a class attribute) to see whether it causes any issue.