Latent Dirichlet Allocation Model

  • What tutorial are you running?
    https://pyro.ai/examples/lda.html
  • What version of Pyro are you using?
    pyro.version
    '1.4.0+2dfa8da0
  • Please link or paste relevant code, and steps to reproduce.

Based on the Pyro LDA tutorial, I changed the code to support varying number of words in documents and applied it to some dataset with 856 documents that contain abstracts of papers.

In the model, I changed with plate to for plate to support varying number of documents. Also, I removed triggering enumeration [pyro.sample("doc_topics_{}".format(doc), dist.Dirichlet(topic_weights),infer={"enumerate": "parallel"}) ] since with for plate it gives error.

   def model(data=None, num_words_per_doc=None, args=None):
       #Globals.
        with pyro.plate("topics", args.num_topics):
            topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
            topic_words = pyro.sample("topic_words",
                                      dist.Dirichlet(torch.ones(args.num_words) / args.num_words)
        # Changed here to from vector(with) to iteration to support varying number 
        # of words (num_words_per_doc) .
        # with pyro.plate("documents", args.num_docs) as ind:
        for doc in pyro.plate("documents", args.num_docs):
            doc_topics = pyro.sample("doc_topics_{}".format(doc), dist.Dirichlet(topic_weights))
            with pyro.plate("words_{}".format(doc), num_words_per_doc[doc]):
                word_topics = pyro.sample("word_topics_{}".format(doc), dist.Categorical(doc_topics))
                pyro.sample("doc_words_{}".format(doc), dist.Categorical(topic_words[word_topics]),
                                   obs=data[doc])
        return topic_weights, topic_words

Then, used same predictor for the guide

# We will use amortized inference of the local topic variables, achieved by a
# multi-layer perceptron. We'll wrap the guide in an nn.Module.
def make_predictor(args):
    layer_sizes = ([args.num_words] +
                   [int(s) for s in args.layer_sizes.split('-')] +
                   [args.num_topics])
    logging.info('Creating MLP with sizes {}'.format(layer_sizes))
    layers = []
    for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
        layer = nn.Linear(in_size, out_size)
        layer.weight.data.normal_(0, 0.001)
        layer.bias.data.normal_(0, 0.001)
        layers.append(layer)
        layers.append(nn.Sigmoid())
    layers.append(nn.Softmax(dim=-1))
    return nn.Sequential(*layers)

Changed the guide according to model since enumeration is removed, I added words_{} plate and slightly changed the predictor part to make it compatible with documents plate for loop. I also removed subsampling since it gives key error in the pyro.infer.Predictive while getting the posterior samples:

def parametrized_guide(predictor, data, num_words_per_doc, args):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(args.num_topics),
            constraint=constraints.positive)
    topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(args.num_topics, args.num_words),
            constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    for doc in pyro.plate("documents", args.num_docs):
        # data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts =  torch.zeros(args.num_words, 1)
        for i in data[doc]: counts[i] += 1 
                    #    .scatter_add(0, data[doc], torch.ones(data[doc].shape)))
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics_{}".format(doc), dist.Delta(doc_topics, event_dim=1))
        # added this part since 
        with pyro.plate("words_{}".format(doc), num_words_per_doc[doc]):
            word_topics = pyro.sample("word_topics_{}".format(doc), dist.Categorical(doc_topics))

In the main method, except data preprocessing and using Trace_ELBO instead of TraceEnum_ELBO [since I removed enumeration] basically it is same as the tutorial.

def main(args):
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)
    documents = pd.read_csv('../Data/ToyExample/papers2017.csv', error_bad_lines=False)
    processed_docs = documents['abstract'].map(preprocess)
    vocabulary = gensim.corpora.Dictionary(processed_docs)
    vocabulary.filter_extremes(no_below=10, no_above=0.5)
data = [torch.tensor(list(filter(lambda a: a!=-1,vocabulary.doc2idx(doc))),dtype=torch.int64) for doc in processed_docs]
    N = list(map(len, data))
    args.num_words = len(vocabulary)
    args.num_docs = len(data)   
map(lambda x:x[1], bow_corpus[1])

# We'll train using SVI.
    logging.info('Training on {} documents'.format(args.num_docs))
    predictor = make_predictor(args)
    guide = functools.partial(parametrized_guide, predictor)
    Elbo = JitTraceEnum_ELBO if args.jit else Trace_ELBO
    elbo = Elbo(max_plate_nesting=2)
    optim = ClippedAdam({'lr': args.learning_rate})
    svi = SVI(model, guide, optim, elbo)
    logging.info('Step\tLoss')
for step in range(args.num_steps):
        loss = svi.step(data, N, args=args)
if step % 10 == 0:
            logging.info('{: >5d}\t{}'.format(step, loss))
    loss = elbo.loss(model, guide, data,N, args=args)
    logging.info('final loss = {}'.format(loss))
num_samples=100
predictive = Predictive(model, guide=guide, num_samples=num_samples)
samples = predictive(data, N, args=args)

When I get the samples, I see that doc_topics_{} are almost the same for all documents even though documents are quite different. And with library like Gensim.model.lda, I was able to produce meaningful topic words and topic distributions for this dataset. The model and guide looks correct to me even though I had to remove enumeration (but instead readded nested plate into the guide). I tried different learning rates(0.005 to 0.5) and the number of steps for SVI up to 5000 but still have the same results. I would appreciate any feedback on that whether the model or guide is wrong.
Thank you in advance.