I’m looking to implement a “vanilla” seq2seq model for translation in Pyro (I want to extend it to variational NMT, but figured I need to get this working before worrying about that). Unfortunately, I’m not comfortable enough with Pyro to be sure I’ve implemented it correctly, so I’m looking for some feedback: currently my model doesn’t seem to learn to translate sequences, even in a simple scenario.
Here’s my model function (the guide function is empty). I’ve tried to follow the DMM (deep Markov model) tutorial, adapting the auxiliary model from the semi-supervised VAE tutorial for my purposes:
```python
import torch
import pyro
import pyro.distributions as dist

def model(self, x_sent, y_sent):
    pyro.module("seq2seq", self)
    # handles getting word embeddings and keeping pair data aligned
    x_embeds, _, _, y_sent = self.x_embed(x_sent, y_sent)  # convert list of strings => PackedSequence
    x_out, s_0 = self.encoder(x_embeds)  # just a PyTorch GRU
    y_labels = self.y_embed.sentences2IndexesAndLens(y_sent)  # convert sentences => word indexes
    T_max = max([y[1] for y in y_labels])  # longest sequence in batch
    y_labels, y_mask = self.y_embed.padAndMask(y_labels, batch_first=self.b_f)
    if self.use_cuda:
        y_labels = y_labels.cuda()
    # Presumably this line makes each data point independent of the others
    with pyro.plate('z_minibatch', len(x_sent)):  # here is where I am completely lost
        # s_0 is (num_directions, batch, hidden) for a 1-layer bidirectional GRU;
        # concatenate the forward and backward final states into (batch, 2 * hidden)
        s_0 = torch.cat([s_0[0], s_0[1]], dim=1)
        # the bridge maps the concatenated state to the decoder's hidden size;
        # unsqueeze(0) adds the num_layers dim the decoder GRU expects
        s_t = self.bridge(s_0).unsqueeze(0)
        for t in range(0, T_max):
            inputs = self.y_embed.getBatchEmbeddings(y_labels[:, t]).unsqueeze(1)
            output, s_t = self.decoder(inputs, s_t)
            output = self.emitter(output.squeeze(1))
            l = y_labels[:, t]
            # alternative I tried: probs=F.softmax(output, dim=1) instead of logits=output
            pyro.sample('y_{}'.format(t),
                        dist.Categorical(logits=output).mask(y_mask[:, t]).to_event(1),
                        obs=l)
```
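One thing worth checking in the decoding loop: the decoder input at step `t` is the embedding of `y_labels[:, t]`, which is the same token the `y_{t}` site then observes, so the decoder can learn to copy its input rather than translate. The usual teacher-forcing setup feeds the previous gold token instead, with a start-of-sentence token at `t = 0`. A minimal pure-PyTorch sketch of building those shifted inputs; `shift_right` and `BOS_IDX` are hypothetical names, and it assumes batch-first, right-padded label tensors like `y_labels` above:

```python
import torch

BOS_IDX = 1  # hypothetical start-of-sentence index in the target vocabulary


def shift_right(y_labels: torch.Tensor, bos_idx: int) -> torch.Tensor:
    """Build decoder inputs for teacher forcing.

    y_labels: (batch, T) gold target indexes, batch-first.
    Returns a (batch, T) tensor whose column t holds the gold token from
    step t - 1, with the start-of-sentence token in column 0.
    """
    bos = torch.full_like(y_labels[:, :1], bos_idx)
    return torch.cat([bos, y_labels[:, :-1]], dim=1)


# Hypothetical labels: 2 = end-of-sentence, 0 = padding.
y_labels = torch.tensor([[5, 6, 7, 2],
                         [8, 9, 2, 0]])
decoder_inputs = shift_right(y_labels, BOS_IDX)
print(decoder_inputs)
# tensor([[1, 5, 6, 7],
#         [1, 8, 9, 2]])

# In the model's loop, step t would then embed decoder_inputs[:, t]
# (instead of y_labels[:, t]) while the y_{t} site still observes y_labels[:, t].
```

At translation time there are no gold tokens, so the input at step `t` would instead be the model's own prediction from step `t - 1`.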
The part I am least certain about is the `pyro.sample` statement in the plate. It looks close to what the DMM tutorial does, but I’m not sure whether there’s some glaring issue I can’t see.