Specifying a vanilla seq2seq in Pyro

I’m looking to implement a “vanilla” seq2seq model for translation in Pyro (I want to extend it to variational NMT eventually, but figured I need to get this working before worrying about that). Unfortunately, I’m not comfortable enough with Pyro yet to be sure I’ve implemented it correctly, and I’m looking for feedback: currently my model doesn’t seem to learn to translate sequences even in a simple scenario.

Here’s my model function. I’ve tried to follow the DMM tutorial example, along with the auxiliary model from the semi-supervised VAE, and adapt them for my purposes; the guide function is empty:

def model(self, x_sent, y_sent):
    pyro.module("seq2seq", self)
    # handles getting word embeddings and keeping pair data aligned
    x_embeds, _, _, y_sent = self.x_embed(x_sent, y_sent)  # convert list of strings => PackedSequence()
    x_out, s_0 = self.encoder(x_embeds)  # just a PyTorch GRU

    y_labels = self.y_embed.sentences2IndexesAndLens(y_sent)  # convert sentences => word indexes

    T_max = max([y[1] for y in y_labels])  # longest sequence in the batch
    y_labels, y_mask = self.y_embed.padAndMask(y_labels, batch_first=self.b_f)
    if self.use_cuda:
        y_labels = y_labels.cuda()
    # presumably this plate makes each data point independent of the others
    with pyro.plate('z_minibatch', len(x_sent)):  # here is where I am completely lost

        # concatenate the two direction states along dim 1, since dim 0 of the
        # GRU hidden state indexes directions
        s_0 = torch.cat([s_0[0], s_0[1]], dim=1)
        # the bridge is a trick to map the concatenated bidirectional encoder
        # state to the decoder's initial hidden state
        s_t = self.bridge(s_0).unsqueeze(0)
        for t in range(T_max):
            inputs = self.y_embed.getBatchEmbeddings(y_labels[:, t]).unsqueeze(1)
            output, s_t = self.decoder(inputs, s_t)
            output = self.emitter(output.squeeze(1))
            l = y_labels[:, t]
            pyro.sample('y_{}'.format(t),
                        dist.Categorical(logits=output).mask(y_mask[:, t]).to_event(1),
                        obs=l)
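
For completeness, the guide is a no-op since every sample site in the model is observed, and training is the standard SVI loop. A rough sketch (the optimizer settings here are placeholders, not my exact configuration):

def guide(self, x_sent, y_sent):
    # nothing to do: all sample sites in the model are observed
    pass

svi = pyro.infer.SVI(self.model, self.guide,
                     pyro.optim.Adam({'lr': 1e-3}),
                     loss=pyro.infer.Trace_ELBO())
loss = svi.step(x_sent, y_sent)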

The part I am least certain about is the pyro.sample statement in the plate. It looks close to what they do in the DMM model, but I’m not sure if there’s some glaring issue I can’t see.
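
For reference, the observation site in the DMM example looks roughly like this (paraphrasing from the tutorial; there the observation at each timestep is a vector of independent Bernoullis, which is why it needs to_event(1)):

pyro.sample('obs_x_{}'.format(t),
            dist.Bernoulli(emission_probs_t)
                .mask(mini_batch_mask[:, t - 1:t])
                .to_event(1),
            obs=mini_batch[:, t - 1, :])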

You might want to look at the shapes and values of output to see if they look weird, e.g. whether they’re all identical along some dimension, and compare them to the shapes and values with the plate removed. All the shape manipulation happening in the model looks suspicious; prefer indexing from right to left using ellipses (...) and negative indices (unsqueeze(-2) instead of unsqueeze(0), etc.). See the Pyro tensor shapes tutorial for more info.
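
As a concrete illustration of what to check (a minimal sketch using names from the model above; the shapes in the comments are assumptions about a typical batch):

# output is assumed to be the emitter's logits with shape [batch, vocab]
d = dist.Categorical(logits=output)
print(d.batch_shape, d.event_shape)    # torch.Size([batch]) torch.Size([])
# .to_event(1) reinterprets the rightmost batch dim as an event dim, so the
# batch dim is no longer left over for the z_minibatch plate to align with:
d2 = d.to_event(1)
print(d2.batch_shape, d2.event_shape)  # torch.Size([]) torch.Size([batch])
# right-to-left indexing stays correct even if extra leading dims appear:
inputs = inputs.unsqueeze(-2)   # instead of unsqueeze(1)
l = y_labels[..., t]            # instead of y_labels[:, t]

If the printed batch_shape doesn’t line up with the plate’s size at the expected dim, that’s usually where a silent broadcasting bug is hiding.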
