Consider a Bayesian neural network for classification in Pyro, defined like so:
class BayesianLinear(PyroModule):
    """Affine layer whose bias and weight are Pyro sample sites.

    Every access to ``self.bias`` / ``self.weight`` inside a model trace
    draws from the priors declared in ``__init__``.
    """

    def __init__(self, in_size: int, out_size: int):
        """Declare priors for a layer mapping ``in_size`` to ``out_size`` features.

        Args:
            in_size: Number of input features (rows of the weight matrix).
            out_size: Number of output features (columns of the weight
                matrix, and the length of the bias vector).
        """
        super().__init__()

        # NOTE(review): a LogNormal prior only yields positive bias values --
        # confirm this is intended rather than a Normal prior.
        bias_shape = torch.Size([out_size])
        bias_prior = dist.LogNormal(0, 1).expand(bias_shape).to_event(1)
        self.bias = PyroSample(prior=bias_prior)

        # Standard-normal prior over the full (in_size, out_size) matrix,
        # treated as a single event.
        weight_shape = torch.Size([in_size, out_size])
        weight_prior = dist.Normal(0, 1).expand(weight_shape).to_event(2)
        self.weight = PyroSample(prior=weight_prior)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``bias + x @ weight`` using freshly accessed parameters."""
        # Read the bias before the weight so the sample sites are visited
        # in the same order as the single-expression form.
        offset = self.bias
        projected = torch.matmul(x, self.weight)
        return offset + projected
class BNNClassifier(PyroModule):
    """Two-layer Bayesian classifier built from ``BayesianLinear`` layers."""

    def __init__(self, input_size: int, latent_size: int, output_size: int):
        """Create the two Bayesian layers.

        Args:
            input_size: Number of input features.
            latent_size: Width of the hidden layer.
            output_size: Number of classes (width of the logit vector).
        """
        super().__init__()
        self.bl1 = BayesianLinear(input_size, latent_size)
        self.bl2 = BayesianLinear(latent_size, output_size)

    def forward(
        self, x: torch.Tensor, y_true: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the network and register the categorical likelihood site ``'y'``.

        Args:
            x: Input batch; ``x.shape[0]`` is used as the size of the
                ``'batch'`` plate.
            y_true: Observed class labels, or ``None`` to sample ``'y'``.

        Returns:
            The unnormalised class scores (logits) produced by ``bl2``.
            Apply ``softmax`` over the last dimension to get probabilities.
        """
        hidden = F.leaky_relu(self.bl1(x))
        # Keep the raw scores: Categorical(logits=...) normalises internally.
        # Feeding it softmax output (as before) applied the normalisation
        # twice and flattened the predictive distribution.
        logits = self.bl2(hidden)
        with pyro.plate('batch', x.shape[0]):
            pyro.sample('y', dist.Categorical(logits=logits), obs=y_true)
        # Returning the logits lets callers read them directly; Predictive
        # is expected to expose the return value when "_RETURN" is listed in
        # return_sites -- confirm against the installed Pyro version.
        return logits
I know that after training with SVI and some AutoGuide, I can use the Predictive module
to sample parameters and target variable y corresponding to some test input x, e.g. like so:
pyro.infer.predictive.Predictive(..., num_samples=1000)(x_test)['y']
This gives me a tensor of sampled class labels (point predictions) of shape [num_samples, batch_size], since Predictive stacks the samples along the leading dimension.
What is the best way to obtain the class probabilities (the softmax of the logits) for each sampled point-estimate model,
i.e. a probability tensor of shape [batch_size, num_samples, num_classes], instead?
I guess one could take the sampled model parameters from Predictive,
load them into a corresponding Pytorch model and make a forward pass using input x.
But this seems kind of like a workaround.