I am running into stability issues with any mixture-of-categoricals model I create, similar to this issue described here (which has no responses), although my setup is a little different.
In the end I am trying to write a "mixture of unigrams" model in Pyro, essentially a simplified Latent Dirichlet Allocation in which each document is assigned a single topic and all of its words are drawn from that topic's word distribution, with some tweaks once I have the basic model set up. I used the Pyro amortized LDA example as a starting place, but when I couldn't get good inference with the amortized guide I switched to something even simpler. The issue is that training seems unstable for every version I set up: it doesn't learn the document-level topic posterior.
Here is the simplest code for reproducing the issue:
import logging
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import ClippedAdam
from pyro.ops.indexing import Vindex
#torch.set_default_tensor_type("torch.cuda.FloatTensor")
logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
num_topics = 8
num_words = 1024
num_docs = 10000
num_words_per_doc = 100
@config_enumerate
def model(data=None):
    # Globals.
    topic_weights = pyro.sample(
        "topic_weights", dist.Dirichlet(torch.ones(num_topics) / num_topics)
    )
    with pyro.plate("topics", num_topics):
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(num_words) / num_words)
        )
    # Locals.
    with pyro.plate("documents", num_docs) as ind:
        doc_topic = pyro.sample("doc_topic", dist.Categorical(topic_weights))
        with pyro.plate("words", num_words_per_doc):
            data = pyro.sample(
                "doc_words",
                dist.Categorical(Vindex(topic_words)[doc_topic]),
                infer={"enumerate": "parallel"},
                obs=data,
            )
    return topic_weights, topic_words, doc_topic, data
def parametrized_guide(data):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(num_topics),
        constraint=constraints.positive,
    )
    topic_weights = pyro.sample(
        "topic_weights", dist.Dirichlet(topic_weights_posterior)
    )
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(num_topics, num_words),
        constraint=constraints.positive,
    )
    with pyro.plate("topics", num_topics):
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))
    doc_topic_posterior = pyro.param(
        "doc_topic_posterior",
        lambda: torch.ones(num_docs, num_topics),
        # constraint=constraints.simplex,
    )
    with pyro.plate("documents", num_docs) as ind:
        doc_topic = pyro.sample(
            "doc_topic",
            dist.Categorical(torch.nn.Softmax(dim=1)(doc_topic_posterior)),
        )
logging.info("Generating data")
pyro.set_rng_seed(4)
pyro.clear_param_store()
# We can generate synthetic data directly by calling the model.
true_topic_weights, true_topic_words, true_doc_topic, data = model()
optim = pyro.optim.Adam({'lr': 0.001, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=2)
pyro.clear_param_store()
global_guide = parametrized_guide
svi = SVI(model, global_guide, optim, loss=elbo)
logging.info("Step\tLoss")
losses = []
for step in range(1000):
    loss = svi.step(data)
    losses.append(loss)
    if step % 10 == 0:
        logging.info("{: >5d}\t{}".format(step, loss))
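The losses list collected above is never used in the script; a minimal sketch for eyeballing the training instability, assuming matplotlib is available (it is not imported above), would be:
import matplotlib.pyplot as plt  # extra dependency, not part of the script above

# Plot the ELBO loss per SVI step to see how (un)stable training is.
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss")
plt.show()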
I have tried many things with this basic model to get it working. I have tried a simplex constraint as well as the Softmax currently used for the document-topic posterior in this guide. I have tried using an auto guide. I have tried the amortized guide with a neural net. I really think there may be an issue in the backprop here, because the most common outcome when I trace the params is that the topic-word distributions converge to all be the same, and the doc-topic posteriors are just uniform. You can check the doc-topic posterior with:
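For reference, the auto guide attempt mentioned above was roughly along these lines (a sketch, not the exact code I ran): an AutoDelta guide over the continuous globals only, with the discrete doc_topic site hidden so that TraceEnum_ELBO enumerates it out in the model.
from pyro import poutine
from pyro.infer.autoguide import AutoDelta

# Guide only the continuous globals; doc_topic is not exposed to the autoguide,
# so TraceEnum_ELBO marginalizes it via enumeration in the model.
auto_guide = AutoDelta(poutine.block(model, expose=["topic_weights", "topic_words"]))
svi_auto = SVI(model, auto_guide, optim, loss=TraceEnum_ELBO(max_plate_nesting=2))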
pyro.get_param_store()["doc_topic_posterior"]
And to see the topic_words posterior convergence you can run:
from pyro import poutine  # needed for the tracing utilities below

guide_trace = poutine.trace(global_guide).get_trace(data)  # record the globals
trained_model = poutine.replay(model, trace=guide_trace)  # replay the globals
trace = poutine.trace(trained_model).get_trace(data)
print(trace.nodes["topic_words"]["value"].detach().numpy())
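One simple way (not in the snippet above) to quantify the "topics all converge to the same distribution" symptom is to look at pairwise L1 distances between the rows of the traced topic_words value:
# Pairwise L1 distances between topic-word distributions; near-zero
# off-diagonal entries mean the topics have collapsed onto each other.
topic_words_value = trace.nodes["topic_words"]["value"].detach()
pairwise_l1 = (topic_words_value.unsqueeze(0) - topic_words_value.unsqueeze(1)).abs().sum(-1)
print(pairwise_l1)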