Relation between model and guide


#1

Hi,

I do not fully understand the relationship between model and guide. Sticking with the Gaussian Mixture model, the model has the following sample statements (see below). Looking at the guide section in the Examples (pyro.ai), one finds the statement: “Since the guide is an approximation to the posterior …, the guide needs to provide a valid joint probability density over all the latent random variables in the model. Recall that when random variables are specified in Pyro with the primitive statement pyro.sample() the first argument denotes the name of the random variable. These names will be used to align the random variables in the model and guide.”

The model has the statement:

weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))

while the guide has no sample with the name ‘weights’. So there is a lack of alignment between model and guide. On the other hand, the sample “assignment” appears in both the guide and the model.

My question: When can certain sample names be omitted from the guide? I simply do not get it yet.

Code:

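import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import config_enumerate

K = 2  # number of mixture components, as in the GMM tutorial

# Pyro 0.2-era API: pyro.iarange and poutine.broadcast were later superseded by pyro.plate.
@config_enumerate(default='parallel')
@poutine.broadcast
def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.iarange('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    with pyro.iarange('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)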

while the guide is:

@config_enumerate(default="parallel")
@poutine.broadcast
def full_guide(data):
    # Global variables.
    with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
        global_guide(data)

    # Local variables.
    with pyro.iarange('data', len(data)):
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment', dist.Categorical(assignment_probs))

#2

OK, there are a few things going on here, some of which are wrapped up in helper code, and that may be contributing to your confusion. In general, every non-observed sample statement in the model should have a corresponding sample statement (with the same name) in the guide. What you're missing in the snippet above is

global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']))

Under the hood, this statement samples weights, locs and scale from Delta distributions (point masses) placed at learnable parameters. It can be written roughly equivalently as

def guide(data):
  auto_weights = pyro.param('auto_weights', ...)
  pyro.sample('weights', dist.Delta(auto_weights))
  ...

and similarly for locs and scale. You can play around with this by removing global_guide(data); with validation enabled (pyro.enable_validation(True)), Pyro will warn that your model has random variables that are missing from your guide.
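To make the "point estimate" reading concrete, here is a plain-Python sketch (no Pyro) for a made-up one-parameter model, mu ~ Normal(0, 10) and x_i ~ Normal(mu, 1). If the guide is Delta(v), the expectation in the ELBO collapses to the single point v, and since Pyro's Delta assigns log-density 0 to its point, the ELBO reduces to log p(x, mu=v). Maximizing it over v is therefore MAP estimation, which is how the AutoDelta documentation describes that guide. The data, step size, and function names are illustrative assumptions.

```python
import math


def log_normal_pdf(x, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def elbo_with_delta_guide(v, data, prior_scale=10.0, obs_scale=1.0):
    # With q(mu) = Delta(v), the expectation is a single point and log q(v) = 0,
    # so the ELBO is just the log joint log p(data, mu=v).
    log_prior = log_normal_pdf(v, 0.0, prior_scale)
    log_lik = sum(log_normal_pdf(x, v, obs_scale) for x in data)
    return log_prior + log_lik


def grad_elbo(v, data, prior_scale=10.0, obs_scale=1.0):
    # Analytic derivative of elbo_with_delta_guide with respect to v.
    return -v / prior_scale ** 2 + sum((x - v) / obs_scale ** 2 for x in data)


def fit_delta_guide(data, lr=0.05, steps=2000):
    v = 0.0  # initial value of the hypothetical point-estimate parameter
    for _ in range(steps):
        v += lr * grad_elbo(v, data)  # plain gradient ascent on the ELBO
    return v


data = [1.8, 2.3, 2.1, 1.6, 2.4]
v_hat = fit_delta_guide(data)
# Closed-form MAP for this conjugate toy model: sum(x) / (n + 1 / prior_scale**2).
v_map = sum(data) / (len(data) + 1 / 10.0 ** 2)
print(v_hat, v_map)
```

Gradient ascent lands on the closed-form MAP value here; SVI with AutoDelta does roughly the same thing, using stochastic gradients and Pyro's optimizers.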


#3

Excellent! That explains a lot. I will reread some of the tutorials with this in mind.


#4

Hi,

Looking more closely at the GMM code in light of what you wrote, I notice that ‘assignment’ is not referenced by global_guide (it is not in the expose list). And yet, I find these lines in cell In [5] of the GMM notebook:

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, global_guide, optim, loss=elbo)

SVI is expected to work, and yet there is no ‘assignment’ variable defined in the guide, and nothing in the guide constrains ‘assignment’.

What am I missing? Thanks.

I might understand. In the model, one finds:

with pyro.iarange('data', len(data)):
    assignment = pyro.sample('assignment', dist.Categorical(weights))
    pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

The variable “assignment” depends on “weights”, and “weights” is included in the guide. So does the guide, by default, use the same distribution for “assignment” as the model does?

If so, why isn't there a default guide that fills in missing variables with the distributions found in the model? Shouldn't that be possible?

Thanks.


#5

Hi @erlebach,

You’re right that assignment appears in the model but not the guide. This is because

During inference, TraceEnum_ELBO will marginalize out the assignments of datapoints to clusters.

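meaning that the guide doesn’t need to sample those values. When we write pyro.sample(..., infer={"enumerate": "parallel"}) in the model, those sample statements are treated specially: they are summed out exactly, so the guide does not need to approximate them. In the GMM model, the @config_enumerate(default='parallel') decorator applies this setting to the model's discrete sample sites, which is how ‘assignment’ gets enumerated. This is a relatively new technique in Pyro, and it leads to lower-variance gradient estimates than sampling the discrete variable in the guide.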
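What "summed out exactly" means can be shown without Pyro. This plain-Python sketch, with made-up parameter values, computes the log-likelihood of data under a 2-component Gaussian mixture with the discrete assignment marginalized: log p(x) = logsumexp_k [log w_k + log Normal(x | loc_k, scale)]. Conceptually this is the sum TraceEnum_ELBO performs over the enumerated values of ‘assignment’, which is why no guide distribution for it is needed; the real implementation works on batched tensors and is far more general.

```python
import math


def log_normal_pdf(x, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(values):
    # Numerically stable log(sum(exp(v) for v in values)).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_log_lik(x, weights, locs, scale):
    # Enumerate every value k of the discrete assignment and sum it out exactly:
    # log p(x) = log sum_k p(assignment=k) * p(x | assignment=k).
    terms = [math.log(w) + log_normal_pdf(x, loc, scale) for w, loc in zip(weights, locs)]
    return logsumexp(terms)


weights, locs, scale = [0.3, 0.7], [0.0, 10.0], 1.5  # illustrative values
data = [0.0, 1.0, 10.0, 11.0, 12.0]
total = sum(marginal_log_lik(x, weights, locs, scale) for x in data)
print(total)
```

The log-sum-exp form keeps the computation stable even when the individual component likelihoods are tiny.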

BTW you’re welcome to clarify the language in the tutorial once you feel you understand what’s going on. Just send us a PR.


#7

May I ask what AutoDelta (as written below) actually returns? I am working on a guide to replace it (just to learn):

global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']))

What does it return? My manual guide is at the end of this post; notice that it has no return value.

Another point concerns the gradient norms. If I keep the “assignment” statement in the global_guide below, only the gradient norm for “scale” decays as in the example; the norms for “weights” and “locs” decay slowly and trace smooth curves. However, if I remove the “assignment” statement from the global guide, I can reproduce the plot in the example. Why would the presence or absence of the “assignment” statement affect the gradient norms? I really do not understand this. Apparently, redundant sample statements in the guide can lead to seemingly incorrect results. This is strange because you stated that “assignment” was not required, which is why it is not in the expose list of the poutine.block() statement.

from matplotlib import pyplot

# gradient_norms is the dict of per-parameter gradient norms recorded during SVI in the tutorial.
pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI')
pyplot.show()

======

@config_enumerate(default='parallel')
@poutine.broadcast
def global_guide(data):
    # Learnable point estimates; the initial values only matter the first
    # time each param is created in the param store.
    auto_weights = pyro.param('auto_weights', torch.ones(K) / K,
                              constraint=constraints.simplex)
    auto_locs = pyro.param('auto_locs', 0.15 * torch.ones(K))
    auto_scale = pyro.param('auto_scale', torch.tensor(0.2),
                            constraint=constraints.positive)

    # 'weights' is one K-vector (a Dirichlet draw in the model), so it belongs
    # outside the 'components' plate and needs event_dim=1.
    weights = pyro.sample('weights', dist.Delta(auto_weights, event_dim=1))
    pyro.sample('scale', dist.Delta(auto_scale))

    with pyro.iarange('components', K):
        pyro.sample('locs', dist.Delta(auto_locs))

    with pyro.iarange('data', len(data)):
        # Local variables.
        # The presence or absence of the following line impacts the gradient norms during SVI.
        pyro.sample('assignment', dist.Categorical(weights))
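One plausible explanation for the gradient-norm difference, offered as an assumption rather than a confirmed diagnosis: once ‘assignment’ is sampled in the guide, TraceEnum_ELBO no longer sums it out in the model but scores it against the guide's distribution. Because this guide is also decorated with @config_enumerate, the expectation over q(assignment) = Categorical(weights) is computed exactly, and since that q equals the prior, log p(z) - log q(z) cancels and the per-datapoint ELBO becomes E_prior[log p(x | z)]. That is below the exact log p(x) by KL(prior || posterior), so the optimizer is following a looser, different objective. The plain-Python sketch below uses made-up numbers to check that identity.

```python
import math


def log_normal_pdf(x, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def exact_log_marginal(x, weights, locs, scale):
    # log p(x) with the assignment summed out (what model-side enumeration computes).
    return math.log(sum(w * math.exp(log_normal_pdf(x, l, scale)) for w, l in zip(weights, locs)))


def elbo_with_prior_guide(x, weights, locs, scale):
    # ELBO with q(z) = Categorical(weights), the prior: log p(z) - log q(z) cancels,
    # leaving the prior-weighted average of log p(x | z).
    return sum(w * log_normal_pdf(x, l, scale) for w, l in zip(weights, locs))


def kl_prior_to_posterior(x, weights, locs, scale):
    # KL(q || p(z | x)) with q = prior; this equals the ELBO gap.
    joint = [w * math.exp(log_normal_pdf(x, l, scale)) for w, l in zip(weights, locs)]
    evidence = sum(joint)
    posterior = [j / evidence for j in joint]
    return sum(w * math.log(w / p) for w, p in zip(weights, posterior))


weights, locs, scale = [0.4, 0.6], [0.0, 10.0], 2.0  # illustrative values
x = 9.0
exact = exact_log_marginal(x, weights, locs, scale)
elbo = elbo_with_prior_guide(x, weights, locs, scale)
gap = kl_prior_to_posterior(x, weights, locs, scale)
print(exact, elbo, gap)
```

Here the point x = 9.0 clearly belongs to the second component, so the prior is a poor stand-in for the posterior and the gap is noticeable.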

#8

Hi again,

I am looking at the LDA tutorial; the code is below. I removed all the sample statements from the parametrized guide, and the code ran without any messages at all. I did not check the results, but I did expect an error message, since the guide was now inconsistent with the model. Could somebody please explain this behavior? I am confused again. Thanks.

"""
This example demonstrates how to marginalize out discrete assignment variables
in a Pyro model.

Our example model is Latent Dirichlet Allocation. While the model in this
example does work, it is not the recommended way of coding up LDA in Pyro.
Whereas the model in this example treats documents as vectors of categorical
variables (vectors of word ids), it is usually more efficient to treat
documents as bags of words (histograms of word counts).
"""
from __future__ import absolute_import, division, print_function

import argparse
import functools
import logging

import torch
from torch import nn
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.optim import Adam

logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.INFO)


# This is a fully generative model of a batch of documents.
# data is a [num_words_per_doc, num_documents] shaped array of word ids
# (specifically it is not a histogram). We assume in this simple example
# that all documents have the same number of words.
def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                               obs=data)

    return topic_weights, topic_words, data


# We will use amortized inference of the local topic variables, achieved by a
# multi-layer perceptron. We'll wrap the guide in an nn.Module.
def make_predictor(args):
    layer_sizes = ([args.num_words] +
                   [int(s) for s in args.layer_sizes.split('-')] +
                   [args.num_topics])
    logging.info('Creating MLP with sizes {}'.format(layer_sizes))
    layers = []
    for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
        layer = nn.Linear(in_size, out_size)
        layer.weight.data.normal_(0, 0.001)
        layer.bias.data.normal_(0, 0.001)
        layers.append(layer)
        layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)


def parametrized_guide(predictor, data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(args.num_topics) / args.num_topics,
            constraint=constraints.positive)
    topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(args.num_topics, args.num_words) / args.num_words,
            constraint=constraints.positive)
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        if torch._C._get_tracing_state():
            # One-hot encode each word id and sum over word positions: [num_words, batch].
            counts = torch.eye(args.num_words)[data[:, ind]].sum(0).t()
        else:
            counts = torch.zeros(args.num_words, ind.size(0))
            counts.scatter_add_(0, data[:, ind], torch.tensor(1.).expand(counts.shape))
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))


def main(args):
    logging.info('Generating data')
    pyro.set_rng_seed(0)
    # We can generate synthetic data directly by calling the model.
    true_topic_weights, true_topic_words, data = model(args=args)

    # We'll train using SVI.
    logging.info('-' * 40)
    logging.info('Training on {} documents'.format(args.num_docs))
    predictor = make_predictor(args)
    guide = functools.partial(parametrized_guide, predictor)
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=2)
    optim = Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, optim, elbo)
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(data, args=args, batch_size=args.batch_size)
        if step % 10 == 0:
            logging.info('{: >5d}\t{}'.format(step, loss))
    loss = elbo.loss(model, guide, data, args=args)
    logging.info('final loss = {}'.format(loss))


if __name__ == '__main__':
    assert pyro.__version__.startswith('0.3.0')
    parser = argparse.ArgumentParser(description="Amortized Latent Dirichlet Allocation")
    parser.add_argument("-t", "--num-topics", default=8, type=int)
    parser.add_argument("-w", "--num-words", default=1024, type=int)
    parser.add_argument("-d", "--num-docs", default=1000, type=int)
    parser.add_argument("-wd", "--num-words-per-doc", default=64, type=int)
    parser.add_argument("-n", "--num-steps", default=1000, type=int)
    parser.add_argument("-l", "--layer-sizes", default="100-100")
    parser.add_argument("-lr", "--learning-rate", default=0.001, type=float)
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument('--jit', action='store_true')
    args = parser.parse_args()
    main(args)
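The guide above converts each document's word-id vector into a histogram (with scatter_add_, or a one-hot trick under the JIT) before feeding it to the MLP. Here is a minimal pure-Python sketch of the same bag-of-words conversion, using the example's layout of [num_words_per_doc, num_documents] for the input and [num_words, num_documents] for the output; the function name and toy data are made up for illustration.

```python
def word_ids_to_histograms(data, num_words):
    """Convert a [num_words_per_doc][num_docs] table of word ids into a
    [num_words][num_docs] table of word counts (one bag of words per column)."""
    num_docs = len(data[0])
    counts = [[0] * num_docs for _ in range(num_words)]
    for row in data:  # one word position across all documents
        for doc, word_id in enumerate(row):
            counts[word_id][doc] += 1  # same effect as scatter_add_ along dim 0
    return counts


# Two documents with three words each, vocabulary of size 4.
data = [[0, 3],
        [2, 3],
        [0, 1]]
hist = word_ids_to_histograms(data, num_words=4)
print(hist)  # [[2, 0], [0, 1], [1, 0], [0, 2]]
```

Each column sums to the number of words per document, which is a quick sanity check on any histogram conversion.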

#9

You can use pyro.enable_validation(True) to show warnings and error messages that are suppressed by default. Please post the modified code you ran along with expected and actual output if you have a more specific question about it.
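Roughly speaking, one of the checks that validation switches on compares the latent sample sites of the model with those of the guide. The plain-Python sketch below illustrates the idea; Site and find_site_mismatches are hypothetical names, not Pyro's API, and Pyro's real check works on execution traces and handles more cases (plates, auxiliary guide sites). It skips observed sites and sites marked for enumeration, and with the site names from the LDA example it flags exactly the variables that were removed from the guide in the post above.

```python
from dataclasses import dataclass


@dataclass
class Site:
    """Hypothetical, simplified record of one pyro.sample statement."""
    name: str
    is_observed: bool = False
    enumerated: bool = False


def find_site_mismatches(model_sites, guide_sites):
    """Return (missing_in_guide, extra_in_guide) as sorted lists of names."""
    # Latent model sites the guide must cover: observed sites are conditioned on,
    # and enumerated sites are summed out, so neither needs a guide counterpart.
    needs_guide = {s.name for s in model_sites if not s.is_observed and not s.enumerated}
    enumerated = {s.name for s in model_sites if s.enumerated}
    guide_names = {s.name for s in guide_sites}
    missing = sorted(needs_guide - guide_names)
    extra = sorted(guide_names - needs_guide - enumerated)
    return missing, extra


lda_model = [Site("topic_weights"), Site("topic_words"), Site("doc_topics"),
             Site("word_topics", enumerated=True), Site("doc_words", is_observed=True)]
print(find_site_mismatches(lda_model, []))
# (['doc_topics', 'topic_weights', 'topic_words'], [])

gmm_model = [Site("weights"), Site("scale"), Site("locs"),
             Site("assignment", enumerated=True), Site("obs", is_observed=True)]
gmm_guide = [Site("weights"), Site("scale"), Site("locs")]
print(find_site_mismatches(gmm_model, gmm_guide))  # ([], [])
```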


#10

Thanks for the advice. I did as you suggested and indeed got the expected messages, once I had also set validate_args=False on a few of the distributions:

topic_words = pyro.sample("topic_words",
                          dist.Dirichlet(torch.ones(args.num_words) / args.num_words,
                                         validate_args=False))  # model
doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights, validate_args=False))  # model
pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior, validate_args=False))  # guide
pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1, validate_args=False))  # guide

These errors had to do with an argument not lying within the support of the distribution (I am not sure how that can happen with a Delta distribution).

Thanks. I will post a specific question about the LDA tutorial in another post categorized as tutorial.

Gordon