Problem about "Multiple observations"

Hi all,

I’m new to Pyro and have a simple question: how do I write the generative process for a model with multiple observations, for example D documents (first observation), each associated with one of L class labels (second observation)?

Assume each document d is associated with a label l, and its topic distribution can be denoted as

\theta_{i} \sim \mathrm{Dir}(\alpha)

Vanilla LDA can be modeled as follows, but I couldn’t figure out how to introduce the labels into the sampling process.

import torch
import pyro
import pyro.distributions as dist

def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            # Keep only the columns for the documents selected by the plate.
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                               obs=data)

    return topic_weights, topic_words, data

My first idea was that the line

doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))

can be replaced by something like

docs_topics = [
    pyro.sample("docs[{}]_topics".format(i), dist.Dirichlet(topic_weights))
    for i in range(args.num_robos)
]

with docs_topics[i] then selected for document d. But then I realized that all documents are passed to obs together, so the generative process still does not distinguish between documents with different labels.
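To make clearer what I am after, here is a rough sketch of the generative process in plain Python (standard library only, not Pyro). It assumes one topic distribution per label, shared by all documents carrying that label, which is just my reading of the idea above. The names generate_corpus, label_topics, sample_dirichlet and sample_categorical are made up for illustration, and I use symmetric Dirichlet(1) priors instead of the Gamma-distributed topic_weights to keep it short.

```python
import random


def sample_dirichlet(alpha, rng):
    """Draw from Dirichlet(alpha) by normalizing independent Gamma draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [x / total for x in draws]


def sample_categorical(probs, rng):
    """Draw an index with probability proportional to probs."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def generate_corpus(labels, num_labels, num_topics, num_words,
                    words_per_doc, seed=0):
    rng = random.Random(seed)

    # Globals: one word distribution per topic, as in vanilla LDA.
    topic_words = [sample_dirichlet([1.0] * num_words, rng)
                   for _ in range(num_topics)]

    # Assumption: one topic distribution per label (not per document).
    label_topics = [sample_dirichlet([1.0] * num_topics, rng)
                    for _ in range(num_labels)]

    docs = []
    for label in labels:
        # The observed label decides which topic distribution the document uses.
        theta = label_topics[label]
        words = []
        for _ in range(words_per_doc):
            z = sample_categorical(theta, rng)                     # word's topic
            words.append(sample_categorical(topic_words[z], rng))  # word id
        docs.append(words)
    return label_topics, topic_words, docs


# Four documents with labels 0, 1, 1, 0.
label_topics, topic_words, docs = generate_corpus(
    labels=[0, 1, 1, 0], num_labels=2, num_topics=3,
    num_words=5, words_per_doc=7)
```

Here the labels are given as input and only select which distribution each document draws from. What I don't know is how to express this selection inside the Pyro model so that both the words and the labels are treated as observed.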

Could someone please help me figure this out?