MaskedDistribution

Hi everyone,
I implemented a sort-of Deep Markov Model (as in your tutorial) and used the MaskedDistribution to cope with sequences of different lengths in the batch. This is the code snippet:

for t in range(1, T_max + 1):
    k = pyro.sample(
        "obs_x_%d" % t,
        dist.OneHotCategorical(probs[:, t - 1, :]).mask(mini_batch_mask[:, t - 1 : t]),
        obs=one_hot(target[:, t - 1], self.embedding.num_embeddings),
    )

Suppose that T_max is the maximum sequence length in the batch (e.g. 41), probs is a three-dimensional tensor of shape [batch_size, max_seq_length, cat_probs] = [16, 41, 40], mini_batch_mask is a two-dimensional boolean tensor of shape [batch_size, max_seq_length] = [16, 41], and finally target is a two-dimensional tensor of shape [16, 40].

From the code above I expected the sample at dist.OneHotCategorical(probs[:, t - 1, :]).mask(mini_batch_mask[:, t - 1 : t]) to have shape [16, 40], but it has shape [16, 16, 40] instead.

Am I wrong? My goal is to prevent, step by step, the padding symbol (0) from “polluting” (to use the same word as in your tutorial) the model computation.

Thank you in advance

this looks like a pytorch indexing question?

do you mean the shape of the sample k is [16, 16, 40]?

do you maybe instead want mini_batch_mask[:, t - 1]?
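For context, the extra dimension comes from broadcasting: .mask() broadcasts the mask's shape against the distribution's batch shape, so a [16, 1] mask against a batch shape of [16] gives a [16, 16] batch. A minimal sketch of just the shapes (with standalone random tensors, not your actual model):

import torch
import pyro.distributions as dist

probs_t = torch.softmax(torch.randn(16, 40), dim=-1)   # stands in for probs[:, t - 1, :]
mask_2d = torch.ones(16, 1, dtype=torch.bool)          # stands in for mini_batch_mask[:, t - 1 : t]
mask_1d = torch.ones(16, dtype=torch.bool)             # stands in for mini_batch_mask[:, t - 1]

d = dist.OneHotCategorical(probs_t)
print(d.batch_shape)                # torch.Size([16])
print(d.mask(mask_2d).batch_shape)  # torch.Size([16, 16]) -- [16] broadcast with [16, 1]
print(d.mask(mask_1d).batch_shape)  # torch.Size([16])     -- the shape you want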

Dear Martin,
thank you for your answer.

Yes, the shape of k is [16, 16, 40].

I actually already implemented the solution you proposed (i.e. using mini_batch_mask[:, t - 1] instead), and it works. However, I would like to be sure that this way the unwanted symbols (e.g. the padding symbol) are really excluded from the computation.

Is it then correct to use this code in place of the previous one?

for t in range(1, T_max + 1):
    k = pyro.sample(
        "obs_x_%d" % t,
        dist.OneHotCategorical(probs[:, t - 1, :]).mask(mini_batch_mask[:, t - 1]),
        obs=one_hot(target[:, t - 1], self.embedding.num_embeddings),
    )
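To convince yourself that the mask really removes the padded positions from the computation, you can check the log-probability directly: wherever the mask is False, the masked distribution contributes exactly zero. A small self-contained check (with made-up tensors in place of your model's, assuming 40 categories as above):

import torch
import pyro.distributions as dist
from torch.nn.functional import one_hot

probs_t = torch.softmax(torch.randn(16, 40), dim=-1)  # stands in for probs[:, t - 1, :]
targets = torch.randint(0, 40, (16,))                 # stands in for target[:, t - 1]
mask_t = torch.zeros(16, dtype=torch.bool)            # pretend every sequence is padded at this step
obs = one_hot(targets, 40).float()

lp = dist.OneHotCategorical(probs_t).mask(mask_t).log_prob(obs)
print(lp.shape)  # torch.Size([16])
print(lp)        # all zeros: masked positions add nothing to the ELBO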

Thank you again for your support and patience.