Ok, so I took the standard SSVAE example in pyro/ss_vae_M2.py, and replaced the standard “encoder_y” network with my own. This is the guide network that produces the distribution over the “y” latent, which is the class of the MNIST digit. The guide and model code are the same as the example; I’m only changing the guide network for “y”. I have three versions:

Network 1: linear layers, one hidden layer
class TempNet(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(TempNet, self).__init__()
        self.lin1 = nn.Linear(n_input, 500)
        # self.lin1b = nn.Linear(500, 500)
        self.lin2 = nn.Linear(500, n_output)

    def forward(self, x):
        # reshape to image form and back, to match the conv network's views
        x = x.view(x.shape[0], 28, 28).unsqueeze(1)
        x = x.view(x.shape[0], 28 * 28)  # was view(1, 28*28), which only works for batch size 1
        x = F.relu(self.lin1(x))
        # x = F.relu(self.lin1b(x))
        x = ClippedSoftmax(1e-7, dim=1)(self.lin2(x))  # epsilon presumably 1e-7, not 1e7
        return x

Network 2: extra hidden linear layer
class TempNet(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(TempNet, self).__init__()
        self.lin1 = nn.Linear(n_input, 500)
        self.lin1b = nn.Linear(500, 500)
        self.lin2 = nn.Linear(500, n_output)

    def forward(self, x):
        # reshape to image form and back, to match the conv network's views
        x = x.view(x.shape[0], 28, 28).unsqueeze(1)
        x = x.view(x.shape[0], 28 * 28)  # was view(1, 28*28), which only works for batch size 1
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin1b(x))
        x = ClippedSoftmax(1e-7, dim=1)(self.lin2(x))  # epsilon presumably 1e-7, not 1e7
        return x

Network 3: extra hidden conv2d layer
class TempNetConv(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(TempNetConv, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3, padding=1)
        self.lin1 = nn.Linear(n_input, 500)
        self.lin2 = nn.Linear(500, n_output)

    def forward(self, x):
        # reshape flat input to (batch, 1, 28, 28) for the conv layer
        x = x.view(x.shape[0], 28, 28).unsqueeze(1)
        x = F.relu(self.conv1(x))
        x = x.view(x.shape[0], 28 * 28)  # was view(1, 28*28), which only works for batch size 1
        x = F.relu(self.lin1(x))
        x = ClippedSoftmax(1e-7, dim=1)(self.lin2(x))  # epsilon presumably 1e-7, not 1e7
        return x
The linear networks include the same view statements as the conv network, to keep the timing comparison fair.
On my machine, 5 epochs with network 1 take between 1:02 and 1:06 (minutes:seconds). Network 2 takes between 1:05 and 1:09, and network 3 takes between 1:17 and 1:18.
So there seems to be something different about the nn.Conv2d layer. Network 2 has far more parameters to learn than network 3 (the extra 500×500 linear layer adds about 250k parameters, while the 3×3 single-channel conv adds only 10), yet it is considerably faster.
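For concreteness, the parameter counts of the three variants can be worked out by hand (a quick sketch in plain Python, with the layer sizes taken from the code above; no torch needed):

```python
def linear_params(n_in, n_out):
    # weight matrix (n_in * n_out) plus bias vector (n_out)
    return n_in * n_out + n_out

def conv2d_params(c_in, c_out, k):
    # one k x k kernel per (in, out) channel pair, plus one bias per out channel
    return c_in * c_out * k * k + c_out

base = linear_params(28 * 28, 500) + linear_params(500, 10)  # lin1 + lin2
net1 = base                            # network 1
net2 = base + linear_params(500, 500)  # network 2: extra 500x500 linear layer
net3 = base + conv2d_params(1, 1, 3)   # network 3: extra 3x3 conv layer
print(net1, net2, net3)  # 397510 648010 397520
```

So network 2 optimizes roughly 250k more parameters than network 3, yet network 3 is the slow one.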
Any ideas on what might be causing this?
Thanks