Several questions about Pyro mini-batch SVI

Hi,

I am trying to implement a BNN for regression that maps 10-dimensional inputs to 3-dimensional predictions, and I have a few questions. First, here is my model:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BNNRegression(PyroModule):
    def __init__(self, in_dim=10, out_dim=3, hid_dim=100, prior_scale=1.0):
        super().__init__()

        zero = torch.tensor(0.).to(device)
        # Cast to float so the prior scale has the same dtype as `zero`
        prior_scale = torch.tensor(float(prior_scale)).to(device)
        self.min_scale = torch.tensor(1e-8).to(device)  # floor for the predicted noise scale

        self.scale_activation = nn.Softplus()
        self.activation = nn.ReLU()

        self.layer1 = PyroModule[nn.Linear](in_dim, hid_dim)   # input -> hidden
        self.layer2 = PyroModule[nn.Linear](hid_dim, out_dim)  # hidden -> predictive mean
        self.layer3 = PyroModule[nn.Linear](hid_dim, out_dim)  # hidden -> predictive scale (before softplus)

        # Independent Normal(0, prior_scale) priors on every weight and bias
        self.layer1.weight = PyroSample(dist.Normal(zero, prior_scale).expand([hid_dim, in_dim]).to_event(2))
        self.layer1.bias = PyroSample(dist.Normal(zero, prior_scale).expand([hid_dim]).to_event(1))
        self.layer2.weight = PyroSample(dist.Normal(zero, prior_scale).expand([out_dim, hid_dim]).to_event(2))
        self.layer2.bias = PyroSample(dist.Normal(zero, prior_scale).expand([out_dim]).to_event(1))
        self.layer3.weight = PyroSample(dist.Normal(zero, prior_scale).expand([out_dim, hid_dim]).to_event(2))
        self.layer3.bias = PyroSample(dist.Normal(zero, prior_scale).expand([out_dim]).to_event(1))

    def forward(self, x, y=None):
        hidden = self.activation(self.layer1(x))
        loc = self.layer2(hidden)
        scale = torch.maximum(self.min_scale, self.scale_activation(self.layer3(hidden)))
        with pyro.plate("data", size=x.shape[0]):  # , subsample_size=1000, device=device):
            # to_event(1): the out_dim outputs of one data point form a single event
            pyro.sample("obs", dist.Normal(loc, scale).to_event(1), obs=y)
        return loc

I am having trouble performing mini-batch SVI. The documentation (the "SVI Part II: Conditional Independence, Subsampling, and Amortization" tutorial in the Pyro Tutorials 1.8.6 documentation) says that subsampling inside a plate can be used for mini-batching, so I have a few questions.

First question: is subsampling with plate effectively the same as drawing mini-batches from a PyTorch train loader and calling svi.step on each mini-batch, like this?

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        svi.step(x, y)

Second, does it make any difference whether I put this block from the forward method inside or outside my plate?

        hidden = self.activation(self.layer1(x))
        loc = self.layer2(hidden)
        scale = torch.maximum(self.min_scale, self.scale_activation(self.layer3(hidden)))

Thanks in advance!

You can pipe in x/y from a data loader, but you need to make sure your plates have the right sizes specified, something like:

num_train = len(train_loader.dataset)  # total number of training data points (len(train_loader) is the number of batches)
subsample_size = len(x)  # size of the current mini-batch
with pyro.plate("data", size=num_train, subsample_size=subsample_size):
    ...

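For the second question: generally speaking, Pyro primitives like plate only affect other Pyro primitives like sample. Deterministic computations done in pure PyTorch aren't affected by Pyro primitives, so it doesn't matter whether they're inside or outside a plate context manager. One caveat for your model: because the weights and biases of layer1, layer2 and layer3 are declared as PyroSample attributes, calling those layers is not purely deterministic, since it triggers pyro.sample statements for their weights. Those global weights should be sampled outside the data plate; otherwise the plate treats them as conditionally independent per data point and broadcasts them over the batch dimension.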