svi.step error when using poutine.mask in a hierarchical model

Hi, I’m new to Pyro, so many thanks in advance for any comments.
This is an attempt to understand effect handlers via a simple modification of the amortized LDA tutorial, so that it ignores part of the data tensor.

Suppose each of the 1000 docs of length 64 is actually padded, and we have:

doclengths = torch.randint(low=20, high=40, size=(1000, 1))
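
For concreteness, the padded data itself might be built like this (the vocabulary size and pad token id here are arbitrary, since the padded entries should be masked out anyway):

import torch

num_words_per_doc, vocab_size = 64, 100  # vocab_size is a placeholder

# Token ids with shape (num_words_per_doc, num_docs), as the tutorial expects;
# positions at or beyond each doc's length are padding and should be masked out.
data = torch.randint(0, vocab_size, (num_words_per_doc, doclengths.shape[0]))
padding = torch.arange(num_words_per_doc).unsqueeze(-1) >= doclengths.T
data[padding] = 0  # arbitrary pad token id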

Then update the model so that poutine.mask ignores the positions at and beyond doclengths[ind] in each doc:

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model(data=None, doclengths=doclengths, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
            doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
            # Build a (batch, num_words_per_doc) grid b of word positions.
            x = ind.view(ind.shape[0], 1)
            y = torch.arange(args.num_words_per_doc).unsqueeze(-2)
            _, b = torch.broadcast_tensors(x, y)

            with pyro.plate("words", args.num_words_per_doc):
                # The word_topics variable is marginalized out during inference,
                # achieved by specifying infer={"enumerate": "parallel"} and using
                # TraceEnum_ELBO for inference. Thus we can ignore this variable in
                # the guide.
                with poutine.mask(mask=(b < doclengths[ind]).unsqueeze(-1)):
                    word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                              infer={"enumerate": "parallel"})
                    data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                                       obs=data)
        else:
            doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
            with pyro.plate("words", args.num_words_per_doc):
                word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                          infer={"enumerate": "parallel"})
                data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                                   obs=data)

    return topic_weights, topic_words, data

This throws an error:

Exception has occurred: RuntimeError
The size of tensor a (32) must match the size of tensor b (8) at non-singleton dimension 0
  File "/home/au/code/pyro_examples/lda_amortized_ragged/__main__.py", line 149, in main
    loss = svi.step(data, args=args, batch_size=args.batch_size)
  File "/home/au/code/pyro_examples/lda_amortized_ragged/__main__.py", line 169, in <module>
    main(args)

The error is raised at the svi.step call:

loss = svi.step(data, args=args, batch_size=args.batch_size)

Hi @mkarikom, diagnosing shape errors can be tricky. It’s hard for us to be more helpful without a full stack trace and a complete runnable example script that reproduces the error, but here are some general tips drawn from the tensor shapes and enumeration tutorials (a combined sketch follows the list):

  1. Assign dimensions to plates manually with the dim= keyword argument to pyro.plate:
...
docs_plate = pyro.plate("documents", args.num_docs, dim=-1)
with docs_plate as ind:
    ...
  2. Use shape assertions aggressively on intermediate values during model development. For example, in this case you might want to add a shape assertion about your mask, or about various distribution parameters:
doclengths_mask = (b < doclengths[ind]).unsqueeze(-1)
assert doclengths_mask.shape == (args.num_words_per_doc, args.num_docs)
with poutine.mask(mask=doclengths_mask):
    ...
  3. If you are using pyro.plate to subsample data and local variables, use the pyro.subsample primitive to safely and automatically apply plate indices to values:
with docs_plate:
    ...
    with words_plate:
        ...
        obs_data = pyro.subsample(data, event_dim=0)
        data = pyro.sample("doc_words", ..., obs=obs_data)
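
Putting the three tips together, here is an untested sketch of how the local part of your model might look (it assumes doclengths, topic_weights, topic_words, data, args, and batch_size from your model are in scope; the subsample_size argument is added for illustration):

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

docs_plate = pyro.plate("documents", args.num_docs, subsample_size=batch_size, dim=-1)
words_plate = pyro.plate("words", args.num_words_per_doc, dim=-2)

with docs_plate as ind:
    doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
    # Tip 3: let the plate apply its subsample indices automatically.
    lengths = pyro.subsample(doclengths.squeeze(-1), event_dim=0)  # (batch,)
    with words_plate:
        obs_data = pyro.subsample(data, event_dim=0)  # (num_words_per_doc, batch)

        # Position j in document i is a real token iff j < lengths[i].
        positions = torch.arange(args.num_words_per_doc).unsqueeze(-1)  # (words, 1)
        doclengths_mask = positions < lengths  # broadcasts to (words, batch)

        # Tip 2: assert the shape before relying on it.
        assert doclengths_mask.shape == (args.num_words_per_doc, len(ind))

        with poutine.mask(mask=doclengths_mask):
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                        obs=obs_data)

Note that computing the mask from a fresh torch.arange, rather than broadcasting against ind, keeps the word-position dimension where the plates expect it (dim -2 for words, dim -1 for documents).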

Also, while the example you pointed to is a nice conceptual illustration of enumeration and amortized inference with Pyro, using dense linear algebra and masking in LDA is going to scale poorly in the minimum and maximum document lengths and vocabulary size (and we should update the description of that example to make this clear!).

If you are interested in applying LDA-like models to a real problem, I recommend reading this tutorial for a more practical approach to topic modeling in Pyro.


Thanks @eb8680_2. As you say, the second tutorial seems preferable for many reasons, both in terms of efficiency and the convenience of the histogram matrix as a portable bag-of-words representation.

Regarding efficiency, I can see how marginalizing over the token-wise assignments (eq. 4 in the original paper) precludes the ragged-data issue, since the likelihood no longer has Dirichlet-Categoricals over individual tokens.
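
For example, the way I picture it (my own sketch, not code from either tutorial): once each document is reduced to a fixed-length histogram of counts over the vocabulary, documents of different lengths all become vectors of the same shape, so there is nothing to pad or mask:

import torch
import pyro
import pyro.distributions as dist

def histogram_likelihood(doc_counts, word_probs):
    # doc_counts: (num_docs, vocab_size) bag-of-words counts per document.
    # word_probs: (num_docs, vocab_size) per-document word probabilities.
    # total_count is set to the longest document; Multinomial.log_prob scores
    # each row by its own total, so shorter documents are still handled.
    total_count = int(doc_counts.sum(-1).max())
    with pyro.plate("documents", doc_counts.shape[0]):
        pyro.sample("obs",
                    dist.Multinomial(total_count, probs=word_probs),
                    obs=doc_counts)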

However, it looks like the second tutorial uses a Laplace approximation, as in the original paper, while the first tutorial introduces an asymmetric Dirichlet prior over the topic weights, similar to Wallach et al. (2009), by sampling concentrations from a Gamma:

topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))

Do you think there is an advantage to the Laplace approximation? Please note that I didn’t quite grasp the reparameterization routine for the above, as described in Jankowiak et al.

Thanks!

The log-Normal parametrization in the ProdLDA tutorial might be more numerically stable if the number of topics is very large or the true posterior is highly skewed. But I think the main reason that parametrization is used is that the paper it was based on predates reparametrized samplers for the Gamma and Dirichlet distributions, and the person who contributed the Pyro tutorial decided to stick closely to the details in that paper.

It should be pretty straightforward to modify the model in the ProdLDA tutorial to use the Gamma-Dirichlet formulation from the first example. Note that you’ll need to add a topic_weights site to the guide as well as the model, as in the guide from the first example.
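
For anyone attempting this, a minimal sketch of what those two sites might look like (names here are placeholders; the real ProdLDA model and guide carry additional neural-network structure):

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

num_topics = 20  # placeholder

def model_fragment():
    # Model side: Gamma-distributed concentrations give an asymmetric
    # Dirichlet prior over the per-document topic weights.
    with pyro.plate("topics", num_topics):
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1.0 / num_topics, 1.0))
    return pyro.sample("doc_topics", dist.Dirichlet(topic_weights))

def guide_fragment():
    # Guide side: a learnable Gamma posterior over the same topic_weights
    # site, as in the guide from the amortized LDA example.
    posterior = pyro.param("topic_weights_posterior",
                           torch.ones(num_topics),
                           constraint=constraints.positive)
    with pyro.plate("topics", num_topics):
        pyro.sample("topic_weights", dist.Gamma(posterior, 1.0))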

People ask about versions of LDA pretty frequently, so if you get that working we’d definitely welcome any pull requests with new or updated examples!


Sounds great, thanks.