Variable length sequences and masking

I am trying to allow variable-length sequences in a language model. I think the right way to do this is with masks, but I currently get an error. My code looks like this:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.ops.indexing import Vindex
from torch.distributions import constraints


class Seq2Hist(torch.nn.Module):
    def __init__(self, D_in, D_out):
        """
        Convert a batch of token-index sequences into per-document
        word-count histograms. D_out is the vocabulary size.
        """
        super(Seq2Hist, self).__init__()
        self.num_words = D_out

    def forward(self, x):
        """
        x is a (num_words_per_doc, num_docs) LongTensor of token indices.
        Returns a (num_docs, num_words) tensor of per-document word counts.
        """
        # Scatter-add every token index into a histogram of vocabulary counts;
        # the extra bucket (index num_words) is dropped before returning.
        out = torch.zeros(self.num_words + 1, x.shape[1]).scatter_add(
            0, x, torch.ones(x.shape)
        ).transpose(1, 0)
        return out[:, :-1]


class MLP(nn.Module):
    def __init__(self, args, eps=1):
        super(MLP, self).__init__()
        self.eps = eps
        
        self.seq2hist = Seq2Hist(args.num_words_per_doc, args.num_words)
        self.prob_layer = nn.Linear(args.num_words, 2)
        nn.init.xavier_uniform_(self.prob_layer.weight)
        self.prob_scale_layer = nn.Linear(args.num_words, 1)
        nn.init.xavier_uniform_(self.prob_scale_layer.weight)
        self.anneal_floor = 0
        
    # forward propagate input
    def forward(self, X):
        self.anneal_floor += 1
        z = self.seq2hist(X)
        background_prob = nn.Softmax(dim=-1)(self.prob_layer(z))
        background_prob_scale = nn.ReLU()(self.prob_scale_layer(z)) + self.eps / self.anneal_floor
        background_prob_prior = background_prob * background_prob_scale
         
        return background_prob_prior

    
def model(data=None, mask=None, args=None, batch_size=None, annealing_factor=1):
    with poutine.scale(None, annealing_factor):
        # Globals.
        topic_weights = pyro.sample(
                "topic_weights", dist.Dirichlet(10 * torch.ones(args.num_topics))
        )
        with pyro.plate("topics", args.num_topics + 1):
            topic_words = pyro.sample(
                "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words),
            )

    # Locals.
    with pyro.plate("documents", data.size(1) if data is not None else args.num_docs):
        with poutine.scale(None, annealing_factor):
            doc_background_prob = pyro.sample("doc_background_prob", 
                                              dist.Beta(10 * torch.ones(1), 10 * torch.ones(1))
                                             ).unsqueeze(-1)

            doc_topic = pyro.sample("doc_topic", dist.Categorical(topic_weights), 
                                   infer={"enumerate": "parallel"}) + 1

        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            with poutine.mask(None, mask):
                word_background_ind = pyro.sample(
                    "word_background_ind",
                    dist.Bernoulli(doc_background_prob.squeeze()),
                    infer={"enumerate": "parallel"}
                )
                topic_ind = (doc_topic * word_background_ind).type(torch.LongTensor)
                data = pyro.sample(
                    "doc_words", dist.Categorical(Vindex(topic_words)[topic_ind]), 
                    obs=data, obs_mask=mask
                )

    return topic_weights, topic_words, doc_topic, data

def parametrized_guide(predictor, data, mask, args, batch_size=None, annealing_factor=1):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive,
    )
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics + 1, args.num_words),
        constraint=constraints.greater_than(1e-8),
    )
    with poutine.scale(None, annealing_factor):
        pyro.sample("topic_weights", dist.Dirichlet(topic_weights_posterior))
        with pyro.plate("topics", args.num_topics + 1):
            pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for continuous local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", data.size(1)):
        background_prob_posterior = predictor(data)
        
        with poutine.scale(None, annealing_factor):
            pyro.sample("doc_background_prob", 
                        dist.Beta(background_prob_posterior[:, 0], 
                                  background_prob_posterior[:, 1]))

The sequences were originally padded with -1, but this gave an out-of-support error. So I switched to padding them with 0, but now the topics all put high probability on the padding value 0, so clearly I’m not masking correctly here, or maybe masking isn’t the right approach…

Any assistance would be helpful!

I just wanted to bump this thread. I think this is a problem related to the length masking here: Example: Hidden Markov Models — Pyro Tutorials 1.8.4 documentation
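My (possibly wrong) reading of the masking pattern in that example is roughly the following; the names num_sequences, lengths and max_length are just how I remember the example, so the details may differ:

with pyro.plate("sequences", num_sequences, dim=-2) as batch:
    lengths = lengths[batch]
    for t in pyro.markov(range(max_length)):
        # mask out time steps past each sequence's true length
        with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
            # sample the latent state and observe the token for step t here
            ...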

But I’m not sure. Guidance much appreciated!

It’s hard to say what could be going wrong, especially without knowing how you’re computing mask or performing inference, but one thing that looks strange in your code is the use of the obs_mask argument for doc_words: you do not need that on top of poutine.mask. You should also assign explicit dimensions to your plates (via the dim= keyword argument) and assert that your mask shape is correct with respect to those dimensions.
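For example, here is a rough, untested sketch of what that could look like, with explicit plate dims, a shape assertion, and poutine.mask alone (no obs_mask). The shapes are assumptions on my part: data as a (num_words_per_doc, num_docs) LongTensor and mask as a boolean tensor of the same shape, True where a token is real and False where it is padding.

def model(data=None, mask=None, args=None, annealing_factor=1):
    # Globals (unchanged apart from keyword-style poutine.scale).
    with poutine.scale(scale=annealing_factor):
        topic_weights = pyro.sample(
            "topic_weights", dist.Dirichlet(10 * torch.ones(args.num_topics))
        )
        with pyro.plate("topics", args.num_topics + 1):
            topic_words = pyro.sample(
                "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words)
            )

    # Locals: documents index the rightmost batch dim of data, words the one above it.
    with pyro.plate("documents", data.size(1), dim=-1):
        with poutine.scale(scale=annealing_factor):
            doc_background_prob = pyro.sample(
                "doc_background_prob", dist.Beta(10 * torch.ones(1), 10 * torch.ones(1))
            )
            doc_topic = pyro.sample(
                "doc_topic", dist.Categorical(topic_weights),
                infer={"enumerate": "parallel"},
            ) + 1

        with pyro.plate("words", args.num_words_per_doc, dim=-2):
            assert mask.shape == (args.num_words_per_doc, data.size(1))
            with poutine.mask(mask=mask):
                word_background_ind = pyro.sample(
                    "word_background_ind",
                    dist.Bernoulli(doc_background_prob),
                    infer={"enumerate": "parallel"},
                )
                topic_ind = (doc_topic * word_background_ind).long()
                pyro.sample(
                    "doc_words",
                    dist.Categorical(Vindex(topic_words)[topic_ind]),
                    obs=data,  # no obs_mask; poutine.mask already handles the padding
                )
    return topic_weights, topic_words, doc_topic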

More broadly, consider copying the usage of poutine.mask for variable-length sequences in our HMM example models as closely as possible, since we know those are correct. I would also strongly encourage you to read the ProdLDA tutorial, which compresses documents into count vectors and uses a multinomial likelihood (avoiding this problem entirely) instead of the explicit word-level approach you are taking here.
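Concretely, the ProdLDA-style route would be to build per-document count vectors once, up front, and use a single vectorized Multinomial likelihood over documents, so the padding never enters the model at all. A rough, untested sketch (to_counts, pad_value and word_probs are made-up names for illustration; word_probs stands in for whatever per-document distribution over the vocabulary your model produces):

def to_counts(padded_docs, num_words, pad_value=-1):
    # padded_docs: (num_words_per_doc, num_docs) LongTensor padded with pad_value.
    # Returns a (num_docs, num_words) tensor of word counts.
    real = (padded_docs != pad_value).float()
    counts = torch.zeros(num_words, padded_docs.shape[1])
    # clamp maps the pad value onto index 0, but its contribution is 0.0
    counts.scatter_add_(0, padded_docs.clamp(min=0), real)
    return counts.t()

# In the model:
with pyro.plate("documents", counts.size(0)):
    pyro.sample(
        "doc_word_counts",
        # total_count is only used for argument validation; Multinomial.log_prob
        # does not depend on it, so documents with different lengths are fine.
        dist.Multinomial(total_count=int(counts.sum(-1).max()), probs=word_probs),
        obs=counts,
    )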

I didn’t think I could use the multinomial likelihood due to inhomogeneous counts! That’s why I moved away from that implementation, which I had working otherwise. OK, I can switch to that. Thank you!