Predict logits with Pyro Bayesian neural network

Consider a Bayesian neural network for classification in Pyro, defined like so:

class BayesianLinear(PyroModule):
    """Affine layer computing ``bias + x @ weight`` with Bayesian parameters.

    ``weight`` (shape ``[in_size, out_size]``) and ``bias`` (shape
    ``[out_size]``) are declared through ``PyroSample``. They are therefore
    latent random variables with the priors below, not fixed ``nn.Parameter``s.

    Args:
        in_size: Number of input features (rows of ``weight``).
        out_size: Number of output features (columns of ``weight`` and
            length of ``bias``).
    """

    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        # NOTE(review): a LogNormal prior restricts every bias entry to be
        # strictly positive. Confirm this is intended (a Normal prior is the
        # more common choice for biases).
        bias_prior = dist.LogNormal(0, 1).expand(torch.Size((out_size,))).to_event(1)
        weight_prior = dist.Normal(0, 1).expand(torch.Size((in_size, out_size))).to_event(2)
        self.bias = PyroSample(prior=bias_prior)
        self.weight = PyroSample(prior=weight_prior)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``bias + x @ weight`` using this call's values of the sites."""
        # Access ``bias`` before ``weight``: that is the left-to-right
        # evaluation order of ``self.bias + x @ self.weight``, so the order in
        # which the two sample sites are hit stays the same.
        bias_draw = self.bias
        weight_draw = self.weight
        return bias_draw + x @ weight_draw


class BNNClassifier(PyroModule):
    """Two-layer Bayesian MLP classifier: BayesianLinear -> leaky ReLU -> BayesianLinear.

    Args:
        input_size: Number of input features fed to the first layer.
        latent_size: Width of the hidden layer.
        output_size: Number of output scores (one per class).
    """

    def __init__(self, input_size: int, latent_size: int, output_size: int) -> None:
        super().__init__()
        self.bl1 = BayesianLinear(input_size, latent_size)
        self.bl2 = BayesianLinear(latent_size, output_size)

    def forward(self, x: torch.Tensor, y_true: torch.Tensor | None = None) -> None:
        """Run the network on ``x`` and declare the categorical label site ``'y'``.

        Args:
            x: Input batch. It is presumably 2-D ``(batch, input_size)``, as
                suggested by the ``dim=1`` softmax and the ``x.shape[0]`` plate
                size. TODO: confirm.
            y_true: Observed class labels passed as ``obs``. When ``None``,
                the ``'y'`` site is unobserved and gets sampled.
        """
        # NOTE(review): F.softmax yields probabilities, not logits. Because
        # Categorical(logits=...) normalizes its argument again, the label
        # distribution becomes softmax(softmax(scores)) instead of
        # softmax(scores). Pass the raw self.bl2 output (or its
        # F.log_softmax) as logits, or use Categorical(probs=...).
        logits = F.softmax(self.bl2(F.leaky_relu(self.bl1(x))), dim=1)
        # One conditionally independent label per row of x.
        with pyro.plate('batch', x.shape[0]):
            pyro.sample('y', dist.Categorical(logits=logits), obs=y_true)

I know that after training with SVI and some AutoGuide, I can use the Predictive class
to sample the parameters and the target variable y corresponding to some test input x, e.g. like so:

# `...` presumably stands for the model/guide arguments elided in the post.
# NOTE(review): Predictive is expected to stack the draws along a leading
# dimension, i.e. ['y'] of shape [num_samples, batch_size]. Confirm.
pyro.infer.predictive.Predictive(..., num_samples=1000)(x_test)['y']

This gives me a tensor of point predictions of shape [num_samples, batch_size] (Predictive stacks the draws along the leading dimension).
What is the best way to instead obtain the logits of each sampled model
(a probability tensor of shape [batch_size, num_samples, num_classes])?

I guess one could take the sampled model parameters from Predictive,
load them into a corresponding PyTorch model and run a forward pass on input x.
But this feels like a workaround.

Hi @dschneider, IIUC you’d like to record the logits during model execution? If that’s so then you can save them using pyro.deterministic:

 def forward(self, x: torch.Tensor, y_true: torch.Tensor | None = None):
-    logits = F.softmax(self.bl2(F.leaky_relu(self.bl1(x))), dim=1)
+    # The raw bl2 outputs are the logits; Categorical(logits=...) normalizes them itself.
+    logits = self.bl2(F.leaky_relu(self.bl1(x)))
+    # Record the logits as a deterministic site so that Predictive also returns them;
+    # apply softmax over the last dim to the returned samples to get class probabilities.
+    pyro.deterministic("logits", logits)
     with pyro.plate('batch', x.shape[0]):
         pyro.sample('y', dist.Categorical(logits=logits), obs=y_true)

BTW, your use of F.softmax() suggests you're computing probs rather than logits; passing those to Categorical(logits=...) effectively applies softmax twice. If you want to compute logits, I think you'd instead either use the output of self.bl2 directly or use F.log_softmax().

1 Like

Hi @fritzo, that’s exactly what I wanted. Also thank you for pointing out the logits-probs mix-up.