How to set batchnorm to evaluation mode in BNN when using SVI


I am trying to sample multiple models for inference from the approximated posterior distribution (i.e. the guide) when using SVI. My model contains batchnorm ops, so what should I do with the sampled models to keep batchnorm running correctly (is model.eval() enough, or is something else needed)?

Here is the relevant part of the model and evaluation code:

class MLP(torch.nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims):
        super(MLP, self).__init__()
        self.fc1 = torch.nn.Linear(input_dims, hidden_dims)
        self.bn1 = torch.nn.BatchNorm1d(hidden_dims)
        self.act1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_dims, hidden_dims)
        self.bn2 = torch.nn.BatchNorm1d(hidden_dims)
        self.act2 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(hidden_dims, output_dims)
        self.act3 = torch.nn.Sigmoid()

    def forward(self, x):
        h1 = self.act1(self.bn1(self.fc1(x)))
        h2 = self.act2(self.bn2(self.fc2(h1)))
        output = self.act3(self.fc3(h2))
        return output
sampled_models = [guide(None, None).eval() for _ in range(num_models)]  # is .eval() needed here?
pred_y = [model(test_x).data.cpu().numpy() for model in sampled_models]
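As a side note on what is at stake here, the following is a minimal plain-PyTorch sketch (not the SVI code above, and not involving the guide) of what .eval() changes for a BatchNorm layer: in training mode each forward pass updates the running statistics from the current batch, while in eval mode the layer normalizes with the stored statistics and leaves them untouched.

```python
import torch

# BatchNorm1d keeps running_mean/running_var buffers that are updated
# only in training mode. Calling .eval() freezes them and makes the
# layer normalize with the stored statistics instead of batch statistics.
bn = torch.nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 5  # batch with non-zero mean/std

bn.train()
_ = bn(x)  # updates the running statistics from this batch

bn.eval()
with torch.no_grad():
    y_eval = bn(x)  # normalizes with the stored running statistics

# In eval mode a forward pass no longer updates the buffers.
mean_before = bn.running_mean.clone()
with torch.no_grad():
    _ = bn(torch.randn(8, 4))
assert torch.allclose(mean_before, bn.running_mean)
```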

Or should I define distributions in the guide to approximate the running_mean and running_var statistics in BatchNorm?
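For reference, one thing I noticed (a plain-PyTorch check, independent of the guide) is that the running statistics are registered as buffers rather than learnable parameters, so a guide that only places distributions over module parameters would not cover them in the first place:

```python
import torch

# BatchNorm's weight/bias are learnable parameters, while
# running_mean/running_var (and num_batches_tracked) are plain buffers.
bn = torch.nn.BatchNorm1d(4)
param_names = {name for name, _ in bn.named_parameters()}
buffer_names = {name for name, _ in bn.named_buffers()}
print(sorted(param_names))   # ['bias', 'weight']
print(sorted(buffer_names))  # ['num_batches_tracked', 'running_mean', 'running_var']
```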

Best regards!

Maybe you can use the same pattern as in the scanvi example.

Thanks for your reply!