# Bayesian Classification

Hi there!
I haven’t found a simple tutorial on how to use Pyro to classify a binary toy dataset such as moons or blobs.
My question is therefore: is my definition of the model below suitable for classification, and can it be trained with SVI or MCMC without any further changes compared with regression tasks? (For example, do I need a loss other than `Trace_ELBO` for SVI?)

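```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class BayesianClassification(PyroModule):
    def __init__(self, in_features=2, hidden_neuron=10, out_features=1):
        super().__init__()
        # Priors: weights ~ Normal(0, 1), biases ~ Normal(0, 10)
        self.layer_1 = PyroModule[nn.Linear](in_features, hidden_neuron)
        self.layer_1.weight = PyroSample(dist.Normal(0., 1.).expand([hidden_neuron, in_features]).to_event(2))
        self.layer_1.bias = PyroSample(dist.Normal(0., 10.).expand([hidden_neuron]).to_event(1))

        self.layer_2 = PyroModule[nn.Linear](hidden_neuron, hidden_neuron)
        self.layer_2.weight = PyroSample(dist.Normal(0., 1.).expand([hidden_neuron, hidden_neuron]).to_event(2))
        self.layer_2.bias = PyroSample(dist.Normal(0., 10.).expand([hidden_neuron]).to_event(1))

        # One output unit (the logit of P(y = 1 | x)); prior shapes must match nn.Linear(hidden, 1)
        self.layer_3 = PyroModule[nn.Linear](hidden_neuron, out_features)
        self.layer_3.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, hidden_neuron]).to_event(2))
        self.layer_3.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        h = F.relu(self.layer_1(x))
        h = F.relu(self.layer_2(h))
        logits = self.layer_3(h).squeeze(-1)  # shape (N,)
        with pyro.plate("data", x.shape[0]):
            # Bernoulli likelihood on logits; y must be a float tensor of 0s and 1s
            pyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)
        return torch.sigmoid(logits)  # P(y = 1 | x) for this draw of the weights
```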

Thanks for helping!


The model looks reasonable to me. If you are using an `AutoNormal` guide, it is better to use the `TraceMeanField_ELBO` objective, which uses analytic KL divergences between guide and prior where they are available.

Thanks!
I tried it with the `TraceMeanField_ELBO` loss, but it does not seem to learn the correct class boundary, e.g. on the moons dataset.
I’m using

```python
self.guide = AutoDiagonalNormal(self.model)
```

The code looks correct to me, except that you should use `AutoNormal` (the mean-field guide) rather than `AutoDiagonalNormal`. You might also want to remove all the `self.` parts (i.e. make `model` and `guide` plain functions rather than methods, and make `svi` a global variable rather than a class attribute) to see whether they cause the issue.