Parameter collapse in simple model with SVI

I am having issues with stability in any mixture-of-categoricals model I create, similar to the issue described here (which has no responses), although my setup is a little different.

In the end I am trying to write a “mixture of unigrams” model in Pyro, essentially a simplified Latent Dirichlet Allocation, with some tweaks once I have the basic model set up. I used the Pyro amortized LDA example as a starting point, but when I couldn’t get good inference with the amortized guide I switched to something even simpler. The issue is that training seems unstable for every version I set up: it doesn’t learn the document-level topic posterior.

Here is the simplest code for reproducing the issue:

import logging

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import ClippedAdam

from pyro.ops.indexing import Vindex

#torch.set_default_tensor_type("torch.cuda.FloatTensor")
logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)

num_topics = 8
num_words = 1024
num_docs = 10000
num_words_per_doc = 100

@config_enumerate
def model(data=None):
    # Globals.
    topic_weights = pyro.sample(
        "topic_weights", dist.Dirichlet(torch.ones(num_topics) / num_topics) 
    )
    with pyro.plate("topics", num_topics):
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(num_words) / num_words)
        )


    # Locals.
    with pyro.plate("documents", num_docs) as ind:
        doc_topic = pyro.sample("doc_topic", dist.Categorical(topic_weights))
        with pyro.plate("words", num_words_per_doc):
            data = pyro.sample(
                "doc_words",
                dist.Categorical(Vindex(topic_words)[doc_topic]),
                infer={"enumerate": "parallel"}, obs=data
            )
    return topic_weights, topic_words, doc_topic, data

def parametrized_guide(data):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(num_topics),
        constraint=constraints.positive,
    )
    topic_weights = pyro.sample(
        "topic_weights", dist.Dirichlet(topic_weights_posterior) 
    )
    
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(num_topics, num_words),
        constraint=constraints.positive,
    )
    with pyro.plate("topics", num_topics):
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))
    
    doc_topic_posterior = pyro.param(
        "doc_topic_posterior",
        lambda: torch.ones(num_docs, num_topics),
        #constraint=constraints.simplex,
    )
    with pyro.plate("documents", num_docs) as ind:
        doc_topic = pyro.sample("doc_topic", dist.Categorical(torch.nn.Softmax(dim=1)(doc_topic_posterior)))

logging.info("Generating data")
pyro.set_rng_seed(4)
pyro.clear_param_store()

# We can generate synthetic data directly by calling the model.
true_topic_weights, true_topic_words, true_doc_topic, data = model()

optim = pyro.optim.Adam({'lr': 0.001, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=2)

pyro.clear_param_store()
global_guide = parametrized_guide
svi = SVI(model, global_guide, optim, loss=elbo)

logging.info("Step\tLoss")
losses = []
for step in range(1000):
    loss = svi.step(data)
    losses.append(loss)
    if step % 10 == 0:
        logging.info("{: >5d}\t{}".format(step, loss))

I have tried many things to get this basic model working. I have tried a simplex constraint as well as the Softmax currently used for the doc-topic posterior in this guide. I have tried an autoguide, and I have tried the amortized guide with a neural net. I really think there may be an issue with backprop here, because the most common outcome when I trace the params is that the topic-word distributions all converge to the same thing and the doc-topic posteriors stay uniform. You can check the doc topic posterior here:

pyro.get_param_store()["doc_topic_posterior"]

And to see the topic_words posterior convergence you can run:

from pyro import poutine  # needed for the trace/replay utilities below

guide_trace = poutine.trace(global_guide).get_trace(data)  # record the globals
trained_model = poutine.replay(model, trace=guide_trace)  # replay the globals
trace = poutine.trace(trained_model).get_trace(data)
print(trace.nodes['topic_words']['value'].detach().numpy())

I have also tried initializing the guide parameters with something other than ones, in case that was the issue, but the topic-word and document-topic posteriors always collapse to the same distributions.
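For example, one such alternative initialization looked roughly like this (just a sketch; the exact values are not important):

# Random positive initialization instead of all-ones (illustrative only).
topic_words_posterior = pyro.param(
    "topic_words_posterior",
    lambda: 0.5 + torch.rand(num_topics, num_words),
    constraint=constraints.positive,
)
doc_topic_posterior = pyro.param(
    "doc_topic_posterior",
    lambda: torch.randn(num_docs, num_topics),  # unconstrained logits, softmaxed in the guide
)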

Your model does not appear to have any observed data (sample statements with an obs keyword).

Ah! Thank you. I must have omitted that when creating the simplified example. I have edited the code above appropriately, but this does not fix the issue.

Have you tried subsampling the data? Generally speaking, batch gradient descent (i.e. computing gradients w.r.t. the entire dataset) works poorly because it’s easy to get stuck in bad local optima. The stochasticity from minibatching can be helpful in escaping such bad local optima and exploring the space more fully.
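As an illustration only (not code from the original post), document subsampling could be added through pyro.plate’s subsample_size argument; the batch size of 512 is an arbitrary choice, and everything not shown stays as in the code above:

# Sketch of minibatching over documents. Because the plate name is shared,
# the guide's subsample indices are reused when the model is run.
def model(data=None):
    ...  # topic_weights and topic_words sampled as above
    with pyro.plate("documents", num_docs, subsample_size=512) as ind:
        doc_topic = pyro.sample("doc_topic", dist.Categorical(topic_weights))
        with pyro.plate("words", num_words_per_doc):
            pyro.sample(
                "doc_words",
                dist.Categorical(Vindex(topic_words)[doc_topic]),
                obs=data[:, ind] if data is not None else None,  # slice the minibatch
            )

def parametrized_guide(data):
    ...  # global sites and params as above
    with pyro.plate("documents", num_docs, subsample_size=512) as ind:
        pyro.sample(
            "doc_topic",
            dist.Categorical(torch.softmax(doc_topic_posterior[ind], dim=-1)),
        )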

I did some more extensive testing around batching, and it really doesn’t seem to help. Subsampling does seem to make the loss converge, but the inference still doesn’t appear to learn anything: the document-level assignments are no better than random and the learned topics are meaningless. It seems very odd that this is happening; I think there must be something wrong with the code above, but I don’t know what…

I believe you need to mark the site "doc_topic" as enumerated in the guide if you want to use TraceEnum_ELBO for inference in this model:

...
# no config_enumerate on the model or guide
def model(...):
    ...

def guide(...):
    ...
    with pyro.plate("documents", ...):
        pyro.sample("doc_topic", dist.Categorical(doc_topic_posterior), infer={"enumerate": "parallel"})

You also don’t need to enumerate "doc_words" in the model, since it’s observed.

See our enumeration and mixture model tutorials for details and background.
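Applying that suggestion to the guide above, the document block might look roughly like this (a sketch; the global sites are unchanged, and @config_enumerate is removed from the model):

def parametrized_guide(data):
    ...  # topic_weights and topic_words sites as before
    doc_topic_posterior = pyro.param(
        "doc_topic_posterior",
        lambda: torch.ones(num_docs, num_topics),
    )
    with pyro.plate("documents", num_docs):
        pyro.sample(
            "doc_topic",
            dist.Categorical(torch.softmax(doc_topic_posterior, dim=-1)),
            infer={"enumerate": "parallel"},  # enumerate the discrete site in the guide
        )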


This has fixed the problem! Thank you so much.
