# Modify LDA (Latent Dirichlet Allocation) for Binary Observations

Hi, I am new to Pyro.

I am following your tutorial, https://pyro.ai/examples/lda.html, which is very clear and easy to read.
I want to apply it to a binary matrix, where the number of words per document is fixed and the observations (words) are drawn from a Beta-Bernoulli distribution instead of a Dirichlet-Categorical distribution.

Let K be the number of topics, S the vocabulary size, and L the number of words in each document. In the simplest case, L is fixed and L = S.
Then:

1. In the original case, where each word is drawn from a Categorical distribution, the topic-word prior has K parameter vectors β, each of length S. So the parameter B = [β_1, β_2, …, β_K] has shape (K, S).
2. In the binary case, where each entry is drawn from a Bernoulli distribution, the topic-word prior should have K * S pairs of Beta parameters (α0, α1). In the Rlda paper, each `φ_cs` (one per topic and entry) in the binary case is drawn from a Beta distribution, whereas each `φ_c` (one per topic) in the multinomial case is drawn from a Dirichlet distribution.
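To make the parameter counts above concrete, here is a minimal sketch using `torch.distributions` (which `pyro.distributions` wraps). It only compares shapes; K = 8 and S = 64 are the sizes used later in this thread, and the concentration values mirror the ones in the code below.

```python
import torch
from torch.distributions import Beta, Dirichlet

K, S = 8, 64  # number of topics, vocabulary size

# Categorical case: K Dirichlet draws, each a length-S probability vector.
dirichlet_prior = Dirichlet(torch.ones(K, S) / S)
# Bernoulli case: K * S independent Beta(0.5, 0.5) draws, one per (topic, word).
beta_prior = Beta(torch.full((K, S), 0.5), torch.full((K, S), 0.5))

phi_cat = dirichlet_prior.sample()  # shape (K, S); each row sums to 1
phi_bin = beta_prior.sample()       # shape (K, S); each entry lies in [0, 1]

print(dirichlet_prior.batch_shape, dirichlet_prior.event_shape)  # torch.Size([8]) torch.Size([64])
print(beta_prior.batch_shape, beta_prior.event_shape)            # torch.Size([8, 64]) torch.Size([])
```

Both priors yield a (K, S) matrix, but the Dirichlet ties each row together as one event, while the Beta treats every entry as its own independent batch element.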

So I modified your code as follows (generative part only):

```python
def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        print("topic weights")
        print(topic_weights.shape)
        # NOTE: phi prior
        # topic_words = pyro.sample("topic_words",
        #                           dist.Dirichlet(torch.ones(args.num_words) / args.num_words))
        topic_words = pyro.sample("topic_words", dist.Beta(
            torch.ones(args.num_words) * 0.5, torch.ones(args.num_words) * 0.5))
        print("topic words")
        print(topic_words.shape)

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            # NOTE: phi prior
            # data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
            #                    obs=data)
            data = pyro.sample("doc_words", dist.Bernoulli(topic_words[word_topics]),
                               obs=data)

    return topic_weights, topic_words, data
```

I tested this code with 8 topics, 64 words, and 1000 documents, but got a shape mismatch error.

Could you help me figure out the error I’ve made? Thanks very much!

please read this tutorial and see if that answers your question.

OK, thanks for your quick reply.

Hi, thanks for the tutorial. I have corrected the generative process and now get the results I want.
The corrected code is shown below:

```python
import torch
import pyro
import pyro.distributions as dist


def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics, dim=-1):
        topic_weights = pyro.sample(
            "topic_weights", dist.Gamma(1. / args.num_topics, 1.))

        # NOTE: beta prior for topic words
        with pyro.plate("words", args.num_words_per_doc):
            topic_words = pyro.sample("topic_words", dist.Beta(
                torch.tensor([0.5]), torch.tensor([0.5])))
            assert topic_words.shape == (args.num_words_per_doc, args.num_topics)

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            ## word_topics (num_words=64, num_docs=1000)
            # print(word_topics)
            # NOTE: bernoulli likelihood
            word_indexes = torch.arange(0, args.num_words_per_doc).unsqueeze(1).repeat(1, args.num_docs)
            # print(topic_words[word_indexes, word_topics])
            data = pyro.sample("doc_words", dist.Bernoulli(topic_words[word_indexes, word_topics]),
                               obs=data)

    return topic_weights, topic_words, data
```

The generated results, with num_docs=1000, num_words=64, and num_topics=8, are shown in the screenshot. However, errors occurred when I ran inference. How should I modify the inference part to match the generative process above?

Thanks! Looking forward to your suggestions!

hi can you please specify in detail all the places you changed the model?

Sure @martinjankowiak. Actually, I only modified two places in the original model. I didn't change how topics are sampled for each document; I only changed the parts related to sampling each word (a binary value in my case) conditional on its topic.

C: number of topics, S: number of words

For place 1, I changed the prior distribution of `topic_words` from a Dirichlet to a Beta distribution. Following the paper Rlda, `topic_words` has `C` parameter vectors `β` (each of length `S`) under a Dirichlet prior, but `C*S` pairs of parameters `(α0, α1)` under a Beta prior. So I think `topic_words` needs to be defined under the "words" plate, so that it is sampled from `C*S` Beta distributions.

In place 2, each word is sampled with probability `φ` (drawn from the prior), which is `topic_words` in the code. With a Dirichlet prior, only the topic assigned to each word is needed to look up `φ_c`, corresponding to `topic_words[word_topics]` in the code. With a Beta prior, both the topic assigned to each word and the position of that word are needed to look up `φ_cs`, corresponding to `topic_words[word_indexes, word_topics]`. `word_indexes` simply holds the row (word-position) index of each entry in the `word_topics` matrix.
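The indexing can be checked with plain PyTorch. This sketch uses random stand-ins for `topic_words` and `word_topics` (not real model output) with the sizes from this thread. It also illustrates two points that go beyond the post and should be read as assumptions: an unrepeated `(S, 1)` row index gives the same result by broadcasting, and, assuming `max_plate_nesting=2` under `TraceEnum_ELBO`, an enumerated `word_topics` of shape `(C, 1, 1)` produces a `(C, S, D)` tensor of probabilities.

```python
import torch

S, C, D = 64, 8, 1000  # word positions, topics, documents

topic_words = torch.rand(S, C)          # stand-in for phi[s, c] = P(entry s = 1 | topic c)
word_topics = torch.randint(C, (S, D))  # stand-in topic assignment for entry s of document d

# The post's version: explicit row indexes, repeated across documents.
word_indexes = torch.arange(0, S).unsqueeze(1).repeat(1, D)
probs = topic_words[word_indexes, word_topics]  # shape (S, D)

# Same result without materializing an (S, D) index: (S, 1) broadcasts against (S, D).
probs_bcast = topic_words[torch.arange(S).unsqueeze(-1), word_topics]

# Under parallel enumeration, word_topics would be a (C, 1, 1) tensor holding 0..C-1;
# advanced indexing then broadcasts to one (S, D) slice of probabilities per topic.
enum_topics = torch.arange(C).reshape(C, 1, 1)
probs_enum = topic_words[word_indexes, enum_topics]  # shape (C, S, D)

print(probs.shape, probs_enum.shape)  # torch.Size([64, 1000]) torch.Size([8, 64, 1000])
```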

## Original Code

```python
def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        # Modified place 1 ---------------------------------------------------
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words))
        # --------------------------------------------------------------------

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            # Modified place 2 -----------------------------------------------
            data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                               obs=data)
            # ----------------------------------------------------------------

    return topic_weights, topic_words, data
```

## Modified Code

```python
import torch
import pyro
import pyro.distributions as dist


def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics, dim=-1):
        topic_weights = pyro.sample(
            "topic_weights", dist.Gamma(1. / args.num_topics, 1.))

        # NOTE: beta prior for topic words
        # Modified place 1 ---------------------------------------------------
        with pyro.plate("words", args.num_words_per_doc):
            topic_words = pyro.sample("topic_words", dist.Beta(
                torch.tensor([0.5]), torch.tensor([0.5])))
            assert topic_words.shape == (args.num_words_per_doc, args.num_topics)
        # --------------------------------------------------------------------

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})

            # Modified place 2 -----------------------------------------------
            # NOTE: bernoulli likelihood
            word_indexes = torch.arange(0, args.num_words_per_doc).unsqueeze(1).repeat(1, args.num_docs)
            # print(topic_words[word_indexes, word_topics])
            data = pyro.sample("doc_words", dist.Bernoulli(topic_words[word_indexes, word_topics]),
                               obs=data)
            # ----------------------------------------------------------------

    return topic_weights, topic_words, data
```

Hi @martinjankowiak, I added the details in a separate post. Looking forward to your suggestions. Thanks!

i’m sorry but this is quite complicated to understand. to make it easier to understand can you please do the following?

for each sample statement add an `assert` statement afterwards so that the shape of every random variable is completely clear in a human intelligible way. something like

`assert topic_words.shape == (args.num_words_per_doc, something_else_i_dont_know)`

does your model run on its own?
`model(data=data, args=args, batch_size=123)`?

Solved, thanks. I don't know why my data, which is drawn from a Bernoulli distribution, has dtype float32. Converting it to int64 fixed the bug.
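A minimal sketch of the likely reason for the dtype issue. `Bernoulli.sample()` returns a floating-point tensor with the dtype of `probs`, not integers. Index-based operations such as `Tensor.scatter_add` require an int64 index; assuming the guide still follows the tutorial, which builds word-count histograms with `scatter_add` indexed by the data, float32 data fails there. The sizes below are illustrative only.

```python
import torch
import torch.distributions as tdist

# Bernoulli samples are floating point (same dtype as `probs`), not integers.
probs = torch.full((64, 10), 0.3)
binary_data = tdist.Bernoulli(probs=probs).sample()
print(binary_data.dtype)  # torch.float32

# scatter_add needs an int64 index, so a float32 index is rejected.
counts = torch.zeros(2, binary_data.shape[1])  # 2 rows: one for value 0, one for value 1
try:
    counts.scatter_add(0, binary_data, torch.ones(binary_data.shape))
    float_index_ok = True
except (RuntimeError, IndexError):
    float_index_ok = False

# Casting to int64, as done in the thread, makes the same call work.
counts = counts.scatter_add(0, binary_data.long(), torch.ones(binary_data.shape))
print(counts.sum(0))  # every column sums to 64, one count per word position
```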