Weights not updating in Amortized Network

Hi,

I have created a very minimal version of the LDA code provided in the examples at pyro.ai.
This code has an amortized network, which is registered via pyro.module in the parametrized_guide function. However, when you run it, you will notice that the weights do not get updated (I have printed out the bias). I do not understand why. Could somebody help me? Thanks.

Gordon

===================

from __future__ import absolute_import, division, print_function

import sys
import argparse
import functools
import logging

import torch
from torch import nn
from torch.distributions import constraints
from torch.distributions import beta

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, Trace_ELBO
from pyro.optim import Adam

logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.INFO)


def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        doc_words = pyro.sample("doc_words", dist.Normal(0., 1.))
        print("doc_words shape: ", doc_words.shape)


# We will use amortized inference of the local topic variables, achieved by a
# multi-layer perceptron. We'll wrap the guide in an nn.Module.
def make_predictor(args):
    layer_sizes = ([args.num_words] +
                   [int(s) for s in args.layer_sizes.split('-')] +
                   [args.num_topics])
    logging.info('Creating MLP with sizes {}'.format(layer_sizes))
    layers = []
    for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
        layer = nn.Linear(in_size, out_size)
        layer.weight.data.normal_(0, 0.001)
        layer.bias.data.normal_(0, 0.001)
        layers.append(layer)
        layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)


def parametrized_guide(predictor, data, args, batch_size=None):
    pyro.module("predictor", predictor)

    """
    # Test code
    ccc = torch.ones([1, 78])   # Temporarily set batch_size to 1
    doc_topics = predictor(ccc)
    pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))
    print("doc_topics= ", doc_topics)  # unchanging
    return
    """

    param_store = pyro.get_param_store()
    # print("parameter names: ", param_store.get_all_param_names())
    # The bias parameter is unchanging. The network is not training!!!! <<<=======
    print("***: ", pyro.param("predictor$$$4.bias"))

    # Use an amortized guide for local variables.
    # For testing, I use a constant fake_data tensor. When calling parametrized_guide
    # multiple times, I expect the weights of the neural network to change as the ELBO
    # is maximized, and the doc_topics to change accordingly. But that is not the case.
    # The weights and doc_topics remain constant. WHY?
    with pyro.plate("topics", args.num_topics):
        fake_data = torch.ones([1, 78])   # Temporarily set batch_size to 1
        doc_topics = predictor(fake_data)
        # The weights are not changing
        print("doc_topics= ", doc_topics)
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))


def main(args):
    model(args=args)

    # Neural network (the amortized guide's predictor)
    predictor = make_predictor(args)

    # Data generation: a fake data tensor of shape (num_docs, num_words),
    # drawn from a Beta distribution.
    V = args.num_words
    m = beta.Beta(torch.tensor([10.]), torch.tensor([10.]) * torch.ones([args.num_docs, V]))
    data = m.sample()

    # We'll train using SVI.
    logging.info('-' * 40)
    logging.info('Training on {} documents'.format(args.num_docs))
    predictor = make_predictor(args)
    guide = functools.partial(parametrized_guide, predictor)
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=2)
    optim = Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, optim, elbo)
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(data, args=args, batch_size=args.batch_size)
        if step % 10 == 0:
            logging.info('{: >5d}\t{}'.format(step, loss))
    loss = elbo.loss(model, guide, data, args=args)
    logging.info('final loss = {}'.format(loss))


if __name__ == '__main__':
    assert pyro.__version__.startswith('0.3.0')
    parser = argparse.ArgumentParser(description="Amortized Latent Dirichlet Allocation")
    parser.add_argument("-t", "--num-topics", default=5, type=int)
    parser.add_argument("-w", "--num-words", default=78, type=int)
    parser.add_argument("-d", "--num-docs", default=1, type=int)
    parser.add_argument("-wd", "--num-words-per-doc", default=64, type=int)
    parser.add_argument("-n", "--num-steps", default=1000, type=int)
    parser.add_argument("-l", "--layer-sizes", default="100-100")
    parser.add_argument("-lr", "--learning-rate", default=0.001, type=float)
    parser.add_argument("-b", "--batch-size", default=4, type=int)
    parser.add_argument('--jit', action='store_true')
    args = parser.parse_args()
    main(args)

I’m not sure I understand what you’re trying to demonstrate with this example - there’s no observed variable in the model, so it’s unaffected by data. You’re training the guide to match the prior, and since you initialized the bias very close to the prior mean, the bias won’t move much.
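To make that concrete, here is a minimal toy sketch (my own, not the LDA code; the model, shapes, and data are made up for illustration): once the model contains an observed site that depends on the latent variable the guide amortizes, the ELBO depends on the data and the predictor's weights receive a gradient signal.

import torch
from torch import nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(20, 4) + 2.0          # toy observations

def model(data):
    loc = pyro.sample("loc", dist.Normal(torch.zeros(4), 1.).to_event(1))
    with pyro.plate("data", data.shape[0]):
        # The observed site ties the latent to the data; without it the
        # guide would only be trained to match the prior.
        pyro.sample("obs", dist.Normal(loc, 1.).to_event(1), obs=data)

net = nn.Linear(4, 4)                     # tiny "amortizing" network

def guide(data):
    pyro.module("net", net)               # registers the network's weights with Pyro
    loc = net(data.mean(0))
    pyro.sample("loc", dist.Delta(loc, event_dim=1))

svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(100):
    svi.step(data)
    if step % 20 == 0:
        print(net.bias.detach())          # the bias now changes from step to step

In your script, by contrast, the only site shared between model and guide has no observation downstream of it, so the best the predictor can do is reproduce the prior it was initialized near.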

Ok, I will revisit. This came about when working with a modification of the LDA example, and I wanted to provide a shorter code example. Stay tuned. Thanks for replying.

Hi,

I am still trying to get the LDA demonstration code to work. To speed up training, I have decreased the number of documents to 100 and the vocabulary size (num_words) to 150. Each document has 64 words and the batch size is 10. I have not otherwise changed the parameters of the problem, although I have experimented with networks of size 10-10 (instead of 100-100) to decrease potential overfitting.


Here is the LDA code from the pyro examples page (pyro 0.3):

"""
This example demonstrates how to marginalize out discrete assignment variables
in a Pyro model.

Our example model is Latent Dirichlet Allocation. While the model in this
example does work, it is not the recommended way of coding up LDA in Pyro.
Whereas the model in this example treats documents as vectors of categorical
variables (vectors of word ids), it is usually more efficient to treat
documents as bags of words (histograms of word counts).
"""
from __future__ import absolute_import, division, print_function

import argparse
import functools
import logging

import torch
from torch import nn
from torch.distributions import constraints

import sys # GE
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.optim import Adam

logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.INFO)
pyro.enable_validation(True)


# This is a fully generative model of a batch of documents.
# data is a [num_words_per_doc, num_documents] shaped array of word ids
# (specifically it is not a histogram). We assume in this simple example
# that all documents have the same number of words.
def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words, validate_args=False))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights, validate_args=False))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                               obs=data)

    #print("topic_weights= ", topic_weights); sys.exit() # GE
    return topic_weights, topic_words, data


# We will use amortized inference of the local topic variables, achieved by a
# multi-layer perceptron. We'll wrap the guide in an nn.Module.
def make_predictor(args):
    layer_sizes = ([args.num_words] +
                   [int(s) for s in args.layer_sizes.split('-')] +
                   [args.num_topics])
    logging.info('Creating MLP with sizes {}'.format(layer_sizes))
    layers = []
    for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
        layer = nn.Linear(in_size, out_size)
        layer.weight.data.normal_(0, 0.001)
        layer.bias.data.normal_(0, 0.001)
        layers.append(layer)
        layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)


def parametrized_guide(predictor, data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(args.num_topics) / args.num_topics,
            constraint=constraints.positive)
    topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(args.num_topics, args.num_words) / args.num_words,
            constraint=constraints.positive)
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior,validate_args=False))
        pass

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        if torch._C._get_tracing_state():
            counts = torch.eye(1024)[data[:, ind]].sum(0).t()
        else:
            counts = torch.zeros(args.num_words, ind.size(0))
            # only works if vocab > # words in doc
            counts.scatter_add_(0, data[:, ind], torch.tensor(1.).expand(counts.shape))
        doc_topics = predictor(counts.transpose(0, 1))
        #print("doc_topics= ", doc_topics)
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1, validate_args=False))


def main(args):
    logging.info('Generating data')
    pyro.set_rng_seed(0)  # GE: always the same seed
    # We can generate synthetic data directly by calling the model.
    true_topic_weights, true_topic_words, data = model(args=args)

    # We'll train using SVI.
    logging.info('-' * 40)
    logging.info('Training on {} documents'.format(args.num_docs))
    predictor = make_predictor(args)
    guide = functools.partial(parametrized_guide, predictor)
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=2)
    optim = Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, optim, elbo)
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(data, args=args, batch_size=args.batch_size)
        if step % 10 == 0:
            print("model, topic_weights_posterior= ", pyro.param("topic_weights_posterior"))
            logging.info('{: >5d}\t{}'.format(step, loss))
    loss = elbo.loss(model, guide, data, args=args)
    logging.info('final loss = {}'.format(loss))


if __name__ == '__main__':
    assert pyro.__version__.startswith('0.3.0')
    parser = argparse.ArgumentParser(description="Amortized Latent Dirichlet Allocation")
    parser.add_argument("-t", "--num-topics", default=8, type=int)
    parser.add_argument("-w", "--num-words", default=150, type=int)
    #parser.add_argument("-w", "--num-words", default=1024, type=int) # orig
    #parser.add_argument("-d", "--num-docs", default=100, type=int)
    parser.add_argument("-d", "--num-docs", default=100, type=int)
    #parser.add_argument("-d", "--num-docs", default=1000, type=int) # orig
    parser.add_argument("-wd", "--num-words-per-doc", default=64, type=int)
    parser.add_argument("-n", "--num-steps", default=1000, type=int)
    parser.add_argument("-l", "--layer-sizes", default="10-10")
    #parser.add_argument("-l", "--layer-sizes", default="100-100")
    parser.add_argument("-lr", "--learning-rate", default=0.001, type=float)
    #parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument("-b", "--batch-size", default=10, type=int)
    parser.add_argument('--jit', action='store_true')
    args = parser.parse_args()
    main(args)

Running the code and printing out pyro.param("topic_weights_posterior") every 10 iterations, one finds:

python b.py
     1023 Generating data
     1081 ----------------------------------------
     1081 Training on 100 documents
     1081 Creating MLP with sizes [150, 10, 10, 8]
     1081 Step	Loss
model, topic_weights_posterior=  tensor([0.1251, 0.1251, 0.1251, 0.1251, 0.1251, 0.1251, 0.1251, 0.1251],
       grad_fn=<AddBackward0>)
     1112     0	47177.16015625
model, topic_weights_posterior=  tensor([0.1261, 0.1259, 0.1261, 0.1261, 0.1261, 0.1262, 0.1260, 0.1261],
       grad_fn=<AddBackward0>)
     1313    10	45491.24609375
model, topic_weights_posterior=  tensor([0.1271, 0.1268, 0.1272, 0.1270, 0.1269, 0.1271, 0.1269, 0.1271],
       grad_fn=<AddBackward0>)
     1512    20	102155.7265625
model, topic_weights_posterior=  tensor([0.1279, 0.1276, 0.1282, 0.1281, 0.1277, 0.1279, 0.1277, 0.1282],
       grad_fn=<AddBackward0>)
     1708    30	79620.7421875
model, topic_weights_posterior=  tensor([0.1288, 0.1283, 0.1291, 0.1294, 0.1287, 0.1288, 0.1284, 0.1291],
       grad_fn=<AddBackward0>)
     1902    40	95385.328125
model, topic_weights_posterior=  tensor([0.1295, 0.1290, 0.1298, 0.1305, 0.1296, 0.1295, 0.1292, 0.1300],
       grad_fn=<AddBackward0>)
     2097    50	102654.625
model, topic_weights_posterior=  tensor([0.1304, 0.1297, 0.1306, 0.1315, 0.1304, 0.1302, 0.1298, 0.1310],
       grad_fn=<AddBackward0>)
     2292    60	92906.4296875
model, topic_weights_posterior=  tensor([0.1315, 0.1304, 0.1314, 0.1324, 0.1311, 0.1311, 0.1304, 0.1318],
       grad_fn=<AddBackward0>)
     2494    70	106757.453125

In other words, the entries of topic_weights_posterior all have approximately the same value. After 1000 iterations the equipartition of topic weights remains. Of course, they do not sum to 1, but that is understandable given the parametrized guide construction (a quick normalization check follows the log below):

model, topic_weights_posterior=  tensor([0.2226, 0.2241, 0.2289, 0.2266, 0.2265, 0.2332, 0.2175, 0.2269],
       grad_fn=<AddBackward0>)
    20826   970	93173.59375
model, topic_weights_posterior=  tensor([0.2240, 0.2250, 0.2301, 0.2279, 0.2275, 0.2343, 0.2187, 0.2281],
       grad_fn=<AddBackward0>)
    21028   980	95545.015625
model, topic_weights_posterior=  tensor([0.2256, 0.2259, 0.2311, 0.2291, 0.2291, 0.2352, 0.2198, 0.2293],
       grad_fn=<AddBackward0>)
    21231   990	50404.92578125
    21481 final loss = 91240.0625
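For reference, here is the rough check behind that statement (my own addition, using the values printed above): in the guide, topic_weights is drawn from Gamma(topic_weights_posterior, 1.), whose mean equals the concentration parameter itself, so normalizing the concentrations gives comparable topic proportions.

import torch

# topic_weights_posterior at step 990, copied from the log above
alpha = torch.tensor([0.2256, 0.2259, 0.2311, 0.2291, 0.2291, 0.2352, 0.2198, 0.2293])

# Under the guide, topic_weights ~ Gamma(alpha, 1.), so E[topic_weights] = alpha.
print(alpha / alpha.sum())   # every entry is close to 1/8 = 0.125, i.e. still equipartitioned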

I really would like to know whether the LDA code only served as an example that was never actually debugged, or whether it is functional. How would you test whether the LDA example code is running correctly? I believe it does not produce correct results, so an example demonstrating that it does would be very useful. For concreteness, the kind of check I have in mind is sketched below.
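This is only a rough idea of mine, not part of the example: since main() generates the data synthetically and model(args=args) returns the true topic-word distributions, one could compare the guide's learned topic_words_posterior against the truth. The helper name below is made up.

import torch

def best_match_similarity(true_topic_words, learned_posterior):
    # Normalize the positive Dirichlet parameters into distributions over words.
    learned = learned_posterior / learned_posterior.sum(-1, keepdim=True)
    # Cosine similarity between every (true topic, learned topic) pair.
    t = true_topic_words / true_topic_words.norm(dim=-1, keepdim=True)
    l = learned / learned.norm(dim=-1, keepdim=True)
    sims = t @ l.t()                     # shape (num_topics, num_topics)
    # For each true topic, the best-matching learned topic. Values near 1 would
    # indicate the topics were recovered; values no better than the similarity
    # between random distributions would indicate nothing was learned.
    return sims.max(dim=1).values

# Usage, after training in main():
# sims = best_match_similarity(true_topic_words, pyro.param("topic_words_posterior").detach())
# print(sims)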

Thank you.

lda is a strange model and it can behave strangely. getting variational inference to work for lda is particularly tricky. these people seem to have gotten it to work with some tricks, but it's not the sort of thing you expect to work out of the box. this example was meant more as a way to demonstrate how to set up a particular model. if it were a full working example, we'd have done things like report test-set perplexities on some dataset.

Ok, thank you for your advice. Frankly, I was convinced I understood the model and was all set to begin modifying it. But I guess I am not ready. I had the paper but had not read it as yet.

Gordon

The LDA variant in that paper isn’t actually all that different in spirit from the example in Pyro, since we have reparametrized gradient estimators for the Dirichlet and Gamma distributions. As Martin says, the example hasn’t been properly tuned and evaluated, isn’t using any tricks like the batch normalization discussed in the paper, and still may not work when modified and applied to an arbitrary new dataset (pull requests welcome!). However, another explanation for the lack of parameter movement you’re currently seeing, rather than a failure to converge, is that the dataset in the example script is drawn from the prior predictive distribution and the guide is already initialized close to the prior.
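One way to probe that explanation (my own suggestion, not part of the example, assuming num_topics is 8 as in your run): generate the synthetic data from topic weights that are far from the prior mean by conditioning the generative model, then retrain and watch whether topic_weights_posterior moves.

import torch
from pyro import poutine

# Skewed topic weights, deliberately far from the Gamma(1/8, 1) prior mean.
skewed_weights = torch.tensor([10., 5., 1., 1., 0.5, 0.5, 0.1, 0.1])

# Condition the model's "topic_weights" site when generating the data,
# then train exactly as before on the resulting dataset.
conditioned_model = poutine.condition(model, data={"topic_weights": skewed_weights})
true_topic_weights, true_topic_words, data = conditioned_model(args=args)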