- What tutorial are you running? LDA
- What version of Pyro are you using? 0.3
- Please link or paste relevant code, and steps to reproduce.
I have reproduced the guide found in the LDA tutorial at pyro.ai. Regarding the parameter `topic_weights_posterior`, shouldn't there be a different set of parameters for each document? Otherwise every document will share the same topic weights. What am I missing? In other words, I would have expected an assignment more like
# Proposed alternative: one concentration row per document, i.e. a
# (num_docs, num_topics) parameter instead of a single shared vector.
topic_weights_posterior = pyro.param(
"topic_weights_posterior",
lambda: torch.ones(args.num_docs, args.num_topics) / args.num_topics,
constraint=constraints.positive)
to generate a 2D array of parameters (one row per document).
Below is the code from the tutorial:
def parametrized_guide(predictor, data, args, batch_size=None):
    """Amortized mean-field guide for the LDA model.

    Global latent variables (per-topic weights and per-topic word
    distributions) get explicitly learned conjugate posterior parameters;
    the per-document topic weights are amortized through the ``predictor``
    neural network instead of having one parameter per document.

    Args:
        predictor: neural net mapping a (batch, num_words) histogram of
            word counts to per-document topic weights.
        data: integer tensor of word ids indexed as
            ``data[:, doc]`` — presumably shape (words_per_doc, num_docs);
            confirm against the model.
        args: namespace providing ``num_topics``, ``num_words`` and
            ``num_docs``.
        batch_size: optional minibatch size for the documents plate.
    """
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics) / args.num_topics,
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words) / args.num_words,
        constraint=constraints.positive)
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        if torch._C._get_tracing_state():
            # Fix: the original hard-coded torch.eye(1024) here, which
            # silently disagrees with the scatter_add_ branch below
            # whenever args.num_words != 1024. Use args.num_words so both
            # paths build a (num_words, batch) histogram.
            counts = torch.eye(args.num_words)[data[:, ind]].sum(0).t()
        else:
            counts = torch.zeros(args.num_words, ind.size(0))
            counts.scatter_add_(0, data[:, ind],
                                torch.tensor(1.).expand(counts.shape))
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))