PyTorch RNN lengths on CPU incompatible with ELBO loss calculation

Hi!

I am just writing to see if you have any idea how to work around this issue https://github.com/pytorch/pytorch/issues/43227

The newest PyTorch (py3.7_cuda10.2.89_cudnn7.6.5_0) has a strange behaviour for RNNs and requires the lengths to be on the CPU. This is not compatible with the calculation of the loss in pyro-ppl 1.5.1:

File "/home/…/anaconda3/lib/python3.7/site-packages/pyro/infer/renyi_elbo.py", line 150, in loss_and_grads
elbo_particle = elbo_particle + log_prob_sum.detach()
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

It makes sense to have everything on the same device, but with the current change in torch that is not possible.
I have asked the PyTorch developers for a solution, because I think this behaviour is undesired and inconvenient.

However, I am also reporting it to you, to see whether it makes sense to do something about it.

Thank you very much for all your work!

Hi @artistworking, I’d guess you could work around this by changing the device of distribution parameters or sample values outside of pyro.sample statements. Could you paste a bit of your model code around the pyro.sample site that leads to the tensor on the wrong device?

Yep, so it is called inside the guide as:

pack_and_padded_sequences = nn.utils.rnn.pack_padded_sequence(reversed_sequences,
sorted_sequences_lengths.cpu(), #pytorch recommended fix
batch_first=True)

And the pytorch error:

 File "/home/.../anaconda3/lib/python3.7/site-packages/torch/nn/utils/rnn.py", line 244, in pack_padded_sequence
    _VF._pack_padded_sequence(input, lengths, batch_first)
RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor
                               Trace Shapes:      

I have tried to transfer pack_and_padded_sequences to the GPU with .cuda(), but it had no effect.

Thanks! :slight_smile:

Hmm, could you paste the code up to and including the next pyro.sample() statement? I would think you could call .to(reversed_sequences.device) on some tensor or other just before the pyro.sample() statement, and then a properly located tensor would be stored in the Pyro device. I find it often works to change dtype or device between pyro.sample() statements, as long as I change back when communicating with pyro via pyro.sample().

OK, that means the entire guide:

def guide(self, sorted_sequences,reversed_sequences,sorted_sequences_lengths,batch_mask, annealing_factor=1.0):
    pyro.module("VAEmodel",self)
    pyro.module("encoder", self.encoder)
    pyro.module("gruGUIDE",self.gruGUIDE)
    pyro.module("PositionalEncoding",self.PositionalEncoding)
    h_0_contig = self.h_0_GUIDE.expand(self.gruGUIDE.num_layers*2, sorted_sequences.size(0), self.gruGUIDE.hidden_size).contiguous()
    pack_and_padded_sequences = nn.utils.rnn.pack_padded_sequence(reversed_sequences, sorted_sequences_lengths,batch_first=True)
    with pyro.plate("data",sorted_sequences.size(0)): 
         gru_output, hidden_states = self.gruGUIDE(pack_and_padded_sequences, h_0_contig)
         z_loc, z_scale = self.encoder(gru_output[:,-1]) 
         with pyro.poutine.scale(scale=annealing_factor):
             if self.iafs.__len__() > 0:
                 z_dist = dist.TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                 pyro.sample("latentGitVectorSpace", z_dist) 
             else:
                 pyro.sample("latentGitVectorSpace", dist.Normal(z_loc, z_scale).to_event(1))

In case it’s relevant, the hidden state of the RNN is initialized as:

self.h_0_GUIDE = nn.Parameter(torch.randn(gru_hidden_dim_Guide), requires_grad=False)

Thanks!!!

Hmm, does the following work? Also can you debug to find the site["name"] at which Pyro errors, say by running under pdb?

  def guide(self, sorted_sequences,reversed_sequences,sorted_sequences_lengths,batch_mask, annealing_factor=1.0):
+     device = sorted_sequences.device
      pyro.module("VAEmodel",self)
      pyro.module("encoder", self.encoder)
      pyro.module("gruGUIDE",self.gruGUIDE)
      pyro.module("PositionalEncoding",self.PositionalEncoding)
      h_0_contig = self.h_0_GUIDE.expand(self.gruGUIDE.num_layers*2, sorted_sequences.size(0), self.gruGUIDE.hidden_size).contiguous()
      pack_and_padded_sequences = nn.utils.rnn.pack_padded_sequence(
          reversed_sequences,
-         sorted_sequences_lengths,
+         sorted_sequences_lengths.cpu(),
          batch_first=True)
+     pack_and_padded_sequences = pack_and_padded_sequences.to(device)
      with pyro.plate("data",sorted_sequences.size(0)): 
          gru_output, hidden_states = self.gruGUIDE(pack_and_padded_sequences, h_0_contig)
          z_loc, z_scale = self.encoder(gru_output[:,-1]) 
          with pyro.poutine.scale(scale=annealing_factor):
              if self.iafs.__len__() > 0:
                  z_dist = dist.TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                  pyro.sample("latentGitVectorSpace", z_dist) 
              else:
                  pyro.sample("latentGitVectorSpace", dist.Normal(z_loc, z_scale).to_event(1))
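For reference, here is a minimal, self-contained sketch of the packing pattern in the diff above, outside of Pyro. The batch size, feature size, and the single-layer unidirectional GRU are made-up stand-ins, not the original model. It assumes only that the lengths must be a CPU int64 tensor while the padded data can stay on the GPU.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy batch: 3 padded sequences, 5 time steps, 4 features,
# already sorted by decreasing length.
sequences = torch.randn(3, 5, 4, device=device)
lengths = torch.tensor([5, 3, 2], device=device)

gru = nn.GRU(input_size=4, hidden_size=8, batch_first=True).to(device)

# Only the lengths are moved to the CPU; the packed data stays on `device`.
packed = pack_padded_sequence(sequences, lengths.cpu(), batch_first=True)
packed_output, h_n = gru(packed)

# The GRU returns a PackedSequence, so unpack it before indexing by time step.
output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)

# Pick the last valid time step of every sequence (not simply index -1,
# which would hit padding for the shorter sequences).
batch_idx = torch.arange(output.size(0), device=device)
last_idx = (output_lengths - 1).to(device)
last_outputs = output[batch_idx, last_idx]
```

In this toy setup `last_outputs` matches `h_n[-1]`, and every tensor that would later reach a `pyro.sample` statement lives on `device`.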

OK, so a change of plans. I found which pyro.sample triggers the error, and it is in the model (everything works without this sample statement):

> def model(....):
>       with pyro.plate("data", sorted_sequences.size(0)): #where sorted_sequences is dim = [50,839,22]:
>                     ....more code.....
>                    inverse = pyro.sample("inverse", dist.HalfNormal(1).expand_by([scales.shape[1],scales.shape[2]]).to_event(2))#Needs to be [50,839,2]

I have already tried .to(device) :slight_smile: , and other distributions, same error.

My Python debugger could not show me the contents of site, so I had to print it… Among the sites with site["type"] == "sample" there are only two names, latentGitVectorSpace and inverse, but I guess we have already established that the problem is inverse, hehe

Thanks and sorry, I mixed up two different errors. It’s an old script that I am updating, hehe

@fritzo Do you have any idea how to debug this and find out why that pyro.sample statement is not on the right device? :grimacing: Thanks!

@artistworking it might be HalfNormal(1) creating a tensor on the wrong device (but only if you haven’t set the default device). How about this solution, ensuring that the HalfNormal is on the right device:

dist.HalfNormal(sorted_sequences.new_ones(()))

or equivalently

dist.HalfNormal(sorted_sequences.new_tensor(1.))
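To see why these two forms help, here is a small sketch using plain torch. `torch.distributions.HalfNormal` stands in for Pyro's `dist.HalfNormal`, and the shape of `sorted_sequences` is a made-up toy size. Both `new_ones` and `new_tensor` inherit the dtype and device of the tensor they are called on.

```python
import torch
from torch.distributions import HalfNormal  # stand-in for pyro's dist.HalfNormal

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy-sized data tensor living on the training device.
sorted_sequences = torch.zeros(2, 3, 4, device=device)

# Both scalars inherit dtype and device from sorted_sequences.
scale_a = sorted_sequences.new_ones(())
scale_b = sorted_sequences.new_tensor(1.0)

# Samples drawn from the distribution then share that device as well.
sample = HalfNormal(scale_a).sample()
print(scale_a.device, scale_b.device, sample.device)
```

By contrast, a distribution built from a plain Python number places its parameter on the default device, which is the CPU unless a different default has been set.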

BTW Are you trying to jit compile? I’d suggest working at first without jit compiling (i.e. try with Trace_ELBO before JitTrace_ELBO). The jit treats torch.Size as a torch.Tensor and then you might need to worry about its device.

Oh my, thanks!!! I would not have guessed any of that. I am using RenyiELBO with num_particles=1.
So it works with RenyiELBO + dist.HalfNormal(sorted_sequences.new_tensor(1.)). It also works with my previous code + Trace_ELBO. So I guess RenyiELBO has some issue somewhere?

But thanks a lot!!! Have a great day :smiley: