Question on plate

  • What tutorial are you running? LDA
  • What version of Pyro are you using? 0.3
  • Please link or paste relevant code, and steps to reproduce.

I have reproduced the guide from the LDA tutorial at pyro.ai. Regarding the parameter topic_weights_posterior: shouldn't there be a separate set of parameters for each document? Otherwise every document will have the same topic weights. What am I missing? In other words, I would have expected an assignment more like:

topic_weights_posterior = pyro.param(
    "topic_weights_posterior",
    lambda: torch.ones(args.num_docs, args.num_topics) / args.num_topics,
    constraint=constraints.positive)

to generate a 2D array of parameters.

Below is the code from the tutorial:

def parametrized_guide(predictor, data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(args.num_topics) / args.num_topics,
            constraint=constraints.positive)
    topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(args.num_topics, args.num_words) / args.num_words,
            constraint=constraints.positive)
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        if torch._C._get_tracing_state():
            counts = torch.eye(1024)[data[:, ind]].sum(0).t()
        else:
            counts = torch.zeros(args.num_words, ind.size(0))
            counts.scatter_add_(0, data[:, ind], torch.tensor(1.).expand(counts.shape))
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))

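topic_weights is the global variable: it is sampled inside the "topics" plate and shared by all documents. doc_topics is the local per-document variable, sampled inside the "documents" plate. The guide does not store a separate parameter for each document. Instead, the predictor network maps each document's word histogram to that document's doc_topics, so different documents still get different topic proportions.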
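
To make the global versus local distinction concrete, here is a minimal sketch in plain PyTorch (no Pyro) that only compares shapes. The sizes, the random data, and the single-layer predictor are assumptions chosen for illustration; they are not the tutorial's actual network or settings.

```python
import torch

# Illustrative sizes (assumptions, not the tutorial's defaults).
num_topics, num_words, num_docs, words_per_doc = 4, 16, 5, 10
torch.manual_seed(0)

# Global guide parameter: one value per topic, shared by every document.
topic_weights_posterior = torch.ones(num_topics) / num_topics

# Hypothetical stand-in for the tutorial's predictor network:
# it maps a word histogram to a point on the topic simplex.
predictor = torch.nn.Sequential(
    torch.nn.Linear(num_words, num_topics),
    torch.nn.Softmax(dim=-1),
)

# Random word indices with shape (words_per_doc, num_docs),
# the same layout the guide above indexes with data[:, ind].
data = torch.randint(num_words, (words_per_doc, num_docs))

# Per-document word histograms, built as in the guide's non-traced branch.
counts = torch.zeros(num_words, num_docs)
counts.scatter_add_(0, data, torch.ones(data.shape))

# Local variable: one row of topic proportions per document.
doc_topics = predictor(counts.transpose(0, 1))

print(topic_weights_posterior.shape)  # torch.Size([4])
print(doc_topics.shape)               # torch.Size([5, 4])
```

The per-document variation comes from the input histogram rather than from a (num_docs, num_topics) parameter table, which is what "amortized" refers to in the guide's comment.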