Are PyTorch convolution operations slow in Pyro?


#1

Hi, I’m trying to adapt the SS-VAE example to do something slightly different, and I’m finding the learning process to be much slower when I add an nn.Conv2d layer to one of my guide networks. More specifically, I’m trying to include a 2D array of Bernoulli distributions as latent variables.

I suspect this might be slow because Pyro sums over all of these Bernoullis (they are discrete, and I’ve specified “enum_discrete=True”), but it seems fast enough when the corresponding guide doesn’t include a conv2d layer.

For example, 5 epochs take about 3.5 minutes without the conv2d layer and about 4 minutes with it.
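How expensive enumeration is depends on how many values have to be summed over. In torch.distributions, a batch of independent Bernoullis enumerates only its two support values along a new leading dimension, whereas summing over every joint configuration of an H×W grid would need 2^(H·W) terms. A small illustration with a hypothetical 4×4 grid; how Pyro's enum_discrete option treats such a site is not settled in this thread.

```python
import torch
from torch.distributions import Bernoulli

# Hypothetical 4x4 grid of independent Bernoulli latents: batch_shape is (4, 4).
probs = torch.full((4, 4), 0.3)
grid = Bernoulli(probs)

# enumerate_support() (expand=True by default) stacks the two support values
# {0, 1} along a new leading dimension, broadcast over the batch.
support = grid.enumerate_support()
print(support.shape)  # torch.Size([2, 4, 4])

# Enumerating every *joint* configuration of the grid instead would need:
print(2 ** probs.numel())  # 65536
```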

Is this expected?


#2

Hmm it’s hard to say. Could you paste simplified versions of the model and guide code?


#3

OK, so I took the standard SS-VAE example in pyro/ss_vae_M2.py and replaced the standard “encoder_y” network with my own. This is the guide network that produces the distribution over the “y” latent, which is the class of the MNIST digit. The guide and model code are the same as the example; I’m only changing the guide network for “y”. I have three versions:

  1. Linear layers, with one hidden layer

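    import torch.nn as nn
    import torch.nn.functional as F
    # assumed: ClippedSoftmax comes from the SS-VAE example's utils/custom_mlp.py
    from utils.custom_mlp import ClippedSoftmax

    class TempNet(nn.Module):

        def __init__(self, n_input, n_hidden, n_output):
            super(TempNet, self).__init__()
            self.lin1 = nn.Linear(n_input, 500)
            # self.lin1b = nn.Linear(500, 500)
            self.lin2 = nn.Linear(500, n_output)

        def forward(self, x):
            # reshape to (batch, 1, 28, 28) and back, to mirror the conv version
            x = x.view(x.shape[0], 28, 28).unsqueeze(1)
            x = x.view(-1, 28*28)

            x = F.relu(self.lin1(x))
            # x = F.relu(self.lin1b(x))
            x = ClippedSoftmax(1e-7, dim=1)(self.lin2(x))
            return x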
    
  2. Extra hidden linear layer

    class TempNet(nn.Module):

        def __init__(self, n_input, n_hidden, n_output):
            super(TempNet, self).__init__()
            self.lin1 = nn.Linear(n_input, 500)
            self.lin1b = nn.Linear(500, 500)
            self.lin2 = nn.Linear(500, n_output)

        def forward(self, x):

            x = x.view(x.shape[0], 28, 28).unsqueeze(1)
            x = x.view(-1, 28*28)

            x = F.relu(self.lin1(x))
            x = F.relu(self.lin1b(x))
            x = ClippedSoftmax(1e-7, dim=1)(self.lin2(x))
            return x
    
  3. Extra hidden conv2d layer

    class TempNetConv(nn.Module):

        def __init__(self, n_input, n_hidden, n_output):
            super(TempNetConv, self).__init__()
            self.conv1 = nn.Conv2d(1, 1, 3, padding=1)
            self.lin1 = nn.Linear(n_input, 500)
            self.lin2 = nn.Linear(500, n_output)

        def forward(self, x):

            x = x.view(x.shape[0], 28, 28).unsqueeze(1)
            x = F.relu(self.conv1(x))
            x = x.view(-1, 28*28)

            x = F.relu(self.lin1(x))
            x = ClippedSoftmax(1e-7, dim=1)(self.lin2(x))
            return x
    

The linear networks include the same view statements as the conv network, to keep the comparison even.

On my machine, 5 epochs with network 1 take between 1:02 and 1:06 (min:sec). Network 2 takes between 1:05 and 1:09, and network 3 takes between 1:17 and 1:18.

So there seems to be something different about the nn.Conv2d layer. Network 2 has far more parameters to learn than network 3 (an extra 500×500 linear layer versus a single 3×3 filter), yet it’s noticeably faster.
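If the same gap shows up in plain PyTorch, outside any Pyro inference loop, then Pyro is not the cause. A minimal sketch for checking this, assuming CPU training on random MNIST-shaped input: TempNetVariant, count_params, and time_train_steps are hypothetical helpers that fold the three networks above into one class, and plain softmax stands in for the example's ClippedSoftmax.

```python
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


class TempNetVariant(nn.Module):
    """Hypothetical single class covering the three TempNet versions.

    Assumption: plain softmax replaces the example's ClippedSoftmax.
    """

    def __init__(self, n_output=10, extra_linear=False, conv=False):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 3, padding=1) if conv else None
        self.lin1 = nn.Linear(28 * 28, 500)
        self.lin1b = nn.Linear(500, 500) if extra_linear else None
        self.lin2 = nn.Linear(500, n_output)

    def forward(self, x):
        x = x.view(x.shape[0], 28, 28).unsqueeze(1)  # (batch, 1, 28, 28)
        if self.conv1 is not None:
            x = F.relu(self.conv1(x))
        x = x.reshape(-1, 28 * 28)
        x = F.relu(self.lin1(x))
        if self.lin1b is not None:
            x = F.relu(self.lin1b(x))
        return F.softmax(self.lin2(x), dim=1)


def count_params(net):
    """Number of trainable scalars in a module."""
    return sum(p.numel() for p in net.parameters() if p.requires_grad)


def time_train_steps(net, n_steps=20, batch_size=200):
    """Mean wall-clock seconds per forward/backward/Adam step on random input."""
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    x = torch.rand(batch_size, 28 * 28)
    y = torch.randint(0, 10, (batch_size,))
    start = time.perf_counter()
    for _ in range(n_steps):
        opt.zero_grad()
        log_probs = torch.log(net(x).clamp_min(1e-7))  # avoid log(0)
        loss = F.nll_loss(log_probs, y)
        loss.backward()
        opt.step()
    return (time.perf_counter() - start) / n_steps


if __name__ == "__main__":
    torch.manual_seed(0)
    variants = {
        "1 linear": TempNetVariant(),
        "2 extra linear": TempNetVariant(extra_linear=True),
        "3 extra conv": TempNetVariant(conv=True),
    }
    for name, net in variants.items():
        ms = time_train_steps(net) * 1e3
        print(f"{name:15s} params={count_params(net):7d}  {ms:6.2f} ms/step")
```

With n_output=10, count_params gives 397,510, 648,010, and 397,520 for versions 1, 2, and 3, so the conv version adds only 10 parameters (nine weights and a bias). Absolute timings depend on the machine and thread settings.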

Any ideas on what might be causing this?

Thanks


#4

Hmm, those differences are pretty small, so it’s tough to say. I’ll make a wild guess: maybe the input is not contiguous, so the convolution operations need to make a copy before operating? You might try inserting assert x.is_contiguous() to see if things are getting messed up only when running under Pyro.
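Since is_contiguous is a method, the check needs parentheses. A minimal sketch of the suggested check; check_contiguous is a hypothetical debugging helper that could wrap x just before self.conv1 in TempNetConv.forward. It also shows what a non-contiguous tensor looks like and that .contiguous() copies it.

```python
import torch


def check_contiguous(name, t):
    """Hypothetical debug helper: fail loudly if `t` is not contiguous."""
    assert t.is_contiguous(), f"{name} is not contiguous, strides={t.stride()}"
    return t


# The reshape used in TempNetConv.forward keeps a contiguous batch contiguous.
x = torch.rand(4, 28 * 28)
img = check_contiguous("img", x.view(x.shape[0], 28, 28).unsqueeze(1))

# A transposed tensor shares storage but has permuted strides, so it is not
# contiguous; .contiguous() copies it into a new, densely packed buffer.
t = torch.rand(4, 1, 28, 28).transpose(2, 3)
t_c = t.contiguous()
print(t.is_contiguous(), t_c.is_contiguous())  # False True
```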


#5

The differences are small in this case, but this is a very contrived model and dataset. Typically there will be several conv layers with many more filters, and the images will be much larger than 28x28. I don’t know whether the time difference grows with the number of layers, filters, or image size, but if it does, it could be a big problem.
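One way to see how convolution cost grows with layers, filters, and image size, independently of Pyro, is to time a plain PyTorch conv stack at a few sizes. A minimal sketch, assuming CPU execution; conv_step_time is a hypothetical helper, and the sizes, batch size, and step count are arbitrary. If the same growth appears here, it comes from the convolutions themselves rather than from Pyro.

```python
import time

import torch
import torch.nn as nn


def conv_step_time(image_size, channels, n_layers, batch_size=16, n_steps=5):
    """Mean seconds for one forward+backward pass through a 3x3 conv stack."""
    layers, in_ch = [], 1
    for _ in range(n_layers):
        layers += [nn.Conv2d(in_ch, channels, 3, padding=1), nn.ReLU()]
        in_ch = channels
    net = nn.Sequential(*layers)
    x = torch.rand(batch_size, 1, image_size, image_size)
    start = time.perf_counter()
    for _ in range(n_steps):
        net.zero_grad()
        net(x).sum().backward()  # scalar loss so backward() needs no grad arg
    return (time.perf_counter() - start) / n_steps


if __name__ == "__main__":
    for size in (28, 56, 112):
        for channels in (1, 16):
            t = conv_step_time(size, channels, n_layers=3)
            print(f"{size:>4}px  {channels:>2} filters  {t * 1e3:8.2f} ms")
```

On a GPU the same loop would need torch.cuda.synchronize() before reading the clock, because CUDA kernels run asynchronously.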

I’ll try the contiguous thing. And maybe someone else can verify that this is happening? Has anyone else tried Pyro on larger CNNs with larger images?


#6

Hi Paul,
I am curious, did you resolve the issue?


#7

I didn’t resolve the issue, but I continued with my model and I don’t recall having any more speed issues. So, maybe I inadvertently fixed it, or maybe I just got used to it.

Are you having similar problems?


#8

I just started converting the VAE code from linear layers to conv2d layers. Could you share your code? That’d be much appreciated.
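For anyone in the same position, here is a minimal sketch of a conv-based encoder that could stand in for an all-linear one on 28×28 MNIST inputs. This is not the code from this thread: ConvEncoder, its layer sizes, and z_dim=50 are illustrative assumptions. It returns the loc and scale of a diagonal Gaussian, the kind of output a VAE guide passes to a Normal distribution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvEncoder(nn.Module):
    """Hypothetical conv encoder: flat 784-vectors -> (loc, scale) of a Gaussian."""

    def __init__(self, z_dim=50, hidden=500):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)             # 28x28 -> 28x28
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)  # 28x28 -> 14x14
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)  # 14x14 -> 7x7
        self.fc = nn.Linear(32 * 7 * 7, hidden)
        self.fc_loc = nn.Linear(hidden, z_dim)
        self.fc_scale = nn.Linear(hidden, z_dim)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)  # flat pixels -> (batch, 1, 28, 28) images
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        h = F.relu(self.fc(x.flatten(start_dim=1)))
        loc = self.fc_loc(h)
        scale = F.softplus(self.fc_scale(h)) + 1e-5  # keep scale strictly positive
        return loc, scale


if __name__ == "__main__":
    loc, scale = ConvEncoder()(torch.rand(8, 784))
    print(loc.shape, scale.shape)
```

The strided convolutions halve the spatial size twice, so the linear layer after them sees 32×7×7 features instead of 784 raw pixels.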