Hi all,

I’m new to Pyro and have a simple question: how do I model a generative process with multiple observations, for example D documents (first observation) together with L associated class labels (second observation)?

Assume each document d is associated with a label l, and its topic distribution can be denoted as

\theta_{i} \sim \mathrm{Dir}(\alpha)



Vanilla LDA can be modeled as follows, but I can’t figure out how to introduce the labels into the sampling process.

import torch
import pyro
import pyro.distributions as dist


def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        # Only index into data when it is given (data=None means sample from the prior).
        if data is not None:
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # word_topics is discrete and marked for parallel enumeration.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                               obs=data)
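
For reference, this is how I read the generative process encoded by the model above. The symbols are my own shorthand, not anything from Pyro: K = args.num_topics, V = args.num_words, D = args.num_docs, N = args.num_words_per_doc, with \alpha for topic_weights, \beta for topic_words, \theta_d for the doc_topics of document d, z for word_topics and w for doc_words.

```latex
\begin{align*}
\alpha_k &\sim \mathrm{Gamma}(1/K,\ 1),             & k &= 1,\dots,K \\
\beta_k  &\sim \mathrm{Dir}(\mathbf{1}_V / V),      & k &= 1,\dots,K \\
\theta_d &\sim \mathrm{Dir}(\alpha),                & d &= 1,\dots,D \\
z_{n,d}  &\sim \mathrm{Cat}(\theta_d),              & n &= 1,\dots,N \\
w_{n,d}  &\sim \mathrm{Cat}(\beta_{z_{n,d}})
\end{align*}
```

Nothing in this process depends on the label l of document d, which is exactly the part I don’t know how to add.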



My first idea was that maybe the line

doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))


can be replaced by something like

docs_topics = [
    pyro.sample("docs[{}]_topics".format(i),
                dist.Dirichlet(topic_weights))
    for i in range(args.num_robos)
]


and then selected for document d with something like docs_topics[i]. But then I realized that all documents are passed to obs together, so the generative process still does not distinguish between documents with different labels.