Parameter collapse when training with an amortized guide

I have a mixture model that works really well when I keep a separate posterior over each document's topic, but it breaks down when I switch to an amortized guide.

My model is based on the Amortized LDA tutorial and is essentially the one I described in my previous post (a mixture of unigrams).

I have tried a number of things to get this working: many combinations of learning rate and learning-rate decay schedule, KL annealing with various annealing schedules, and different batch sizes. I have also gone back and forth between “subsampling” and standard minibatching. Nothing lets the neural-network-based guide come close to the model with local variables.
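In the code posted further down, KL annealing works by passing an annealing_factor into both the model and the guide, where poutine.scale multiplies the log-densities of the wrapped latent sites by it. A minimal sketch of one such schedule, a linear warm-up, assuming the factor is recomputed at every SVI step; the function name and warm-up length are hypothetical, not taken from the post.

```python
def linear_annealing_factor(step, warmup_steps, min_factor=0.01):
    """Hypothetical linear KL-annealing schedule (not from the original post).

    Ramps the factor linearly from min_factor at step 0 to 1.0 at
    step == warmup_steps, then holds it at 1.0.
    """
    if warmup_steps <= 0:
        return 1.0
    frac = min(1.0, step / warmup_steps)
    return min_factor + (1.0 - min_factor) * frac


# Factors for the first six SVI steps with a 4-step warm-up starting from 0
print([linear_annealing_factor(s, 4, min_factor=0.0) for s in range(6)])
# → [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

The returned value would be fed in as annealing_factor at each step; a small nonzero min_factor keeps the KL term from vanishing entirely at the start.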

With the neural net I see general parameter collapse: the network predicts roughly the same topic for every document, and the topic-word distributions all converge to one another.

I want this amortized guide to work because I have a dataset of hundreds of millions of documents to cluster with this method. I obviously cannot fit all of that onto a GPU, so an amortized guide seems like the best fit.

I am wondering if there’s anything I can do to improve the performance of the amortized guide. It really just doesn’t learn the model at all right now.

It’s hard for us to be helpful unless you can provide mathematical details or source code at least for your model and guide.

However, one general thing I don’t understand about the setup in your previous topic is why you want to do approximate inference over doc_topic, amortized or otherwise, if it is just a categorical variable, as opposed to a Dirichlet in standard LDA. Why not just integrate it out exactly during training with model-side enumeration in TraceEnum_ELBO (as opposed to the guide-side enumeration in your previous post), then sample from its exact posterior given approximate samples of the global parameters using infer_discrete? This process is described in detail in our Gaussian mixture model tutorial.
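To make “integrate it out exactly” concrete: doc_topic takes only K values, so summing over them is cheap. The plain-Python sketch below (not Pyro code; toy numbers are made up, function names are hypothetical, and the background component of the later model is ignored) computes the per-document quantity that model-side enumeration lets TraceEnum_ELBO use instead of a sampled doc_topic: log p(words | globals) = logsumexp over k of [log topic_weights[k] + sum over words of log topic_words[k][w]].

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def doc_log_marginal(words, topic_weights, topic_words):
    """log p(words | globals) for one document, with its topic summed out.

    words:         word indices of the document
    topic_weights: K mixture weights (a probability vector)
    topic_words:   K word distributions, one per topic
    """
    # log p(topic = k) + log p(words | topic = k), for every topic k
    log_joint = [
        math.log(topic_weights[k])
        + sum(math.log(topic_words[k][w]) for w in words)
        for k in range(len(topic_weights))
    ]
    # Summing over k in probability space == log-sum-exp in log space
    return log_sum_exp(log_joint)


# Toy example: 2 topics over a 3-word vocabulary
topic_weights = [0.5, 0.5]
topic_words = [[0.8, 0.1, 0.1], [0.1, 0.1, 0.8]]
print(doc_log_marginal([0, 0, 2], topic_weights, topic_words))
```

Because no guide distribution over doc_topic is involved, there is nothing for an amortized network to collapse on this variable.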

You might also be interested in reading our more complete ProdLDA tutorial for a better approach to LDA-like models in Pyro.

Right, it's because my actual model is a bit more complicated:

import logging

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.distributions import constraints
from pyro.ops.indexing import Vindex


class Seq2Hist(torch.nn.Module):
    def __init__(self, D_in, D_out):
        """
        Converts a (num_words_per_doc, num_docs) LongTensor of word indices into
        a (num_docs, num_words) bag-of-words count histogram. D_in is unused.
        """
        super(Seq2Hist, self).__init__()
        self.num_words = D_out

    def forward(self, x):
        # out[x[i, j], j] += 1 for every word position i of document j;
        # tensors are created on x's device so this also works on GPU.
        out = torch.zeros(self.num_words, x.shape[1], device=x.device).scatter_add(
            0, x, torch.ones(x.shape, device=x.device)
        ).transpose(1, 0)
        return out

def model(data=None, args=None, batch_size=None, annealing_factor=1):
    with poutine.scale(None, annealing_factor):
        # Globals.
        topic_weights = pyro.sample(
            "topic_weights", dist.Dirichlet(10 * torch.ones(args.num_topics)))
        with pyro.plate("topics", args.num_topics + 1):
            topic_words = pyro.sample(
                "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Locals.
    with pyro.plate("documents", data.size(1) if data is not None else args.num_docs) as ind:
        with poutine.scale(None, annealing_factor):
            doc_background_prob = pyro.sample(
                "doc_background_prob", dist.Beta(10 * torch.ones(1), 10 * torch.ones(1)))

            # Topics are 1-based; row 0 of topic_words is the background.
            doc_topic = pyro.sample("doc_topic", dist.Categorical(topic_weights)) + 1

        with pyro.plate("words", args.num_words_per_doc):
            # The word_background_ind variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # Assumed reconstruction (this line was lost from the post): 1 selects
            # the document's topic row, 0 selects the background row topic_words[0].
            word_background_ind = pyro.sample(
                "word_background_ind", dist.Bernoulli(1 - doc_background_prob),
                infer={"enumerate": "parallel"})
            # .long() keeps the tensor on its current device (unlike torch.LongTensor).
            topic_ind = (doc_topic * word_background_ind).long()
            data = pyro.sample(
                "doc_words", dist.Categorical(Vindex(topic_words)[topic_ind]), obs=data)

    return topic_weights, topic_words, doc_topic, data

class MLP(nn.Module):
    def __init__(self, args, eps=1):
        super(MLP, self).__init__()
        layer_sizes = (
            [args.num_words]
            #  + [int(s) for s in args.layer_sizes.split("-")]
        )
        logging.info("Creating MLP with sizes {}".format(layer_sizes))
        self.eps = eps
        self.seq2hist = Seq2Hist(args.num_words_per_doc, args.num_words)
        # nn.ModuleList (not a plain list) so hidden layers get registered
        # by pyro.module and are actually trained.
        self.layers = nn.ModuleList()
        # for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
        #     layer = nn.Linear(in_size, out_size)
        #     nn.init.xavier_uniform_(layer.weight)
        #     self.layers.append(layer)
        self.prob_layer = nn.Linear(layer_sizes[-1], 2)
        self.prob_scale_layer = nn.Linear(layer_sizes[-1], 1)
        self.topic_layer = nn.Linear(layer_sizes[-1], args.num_topics)
        self.topic_act = nn.Softmax(dim=-1)
        self.anneal_floor = 0

    # forward propagate input
    def forward(self, X):
        self.anneal_floor += 1
        # Word-index sequences -> bag-of-words histogram
        z = self.seq2hist(X)
        for layer in self.layers:
            z = torch.relu(layer(z))
        # Beta concentrations: a softmax split times a positive scale whose
        # floor eps / anneal_floor decays with the number of forward calls.
        background_prob = torch.softmax(self.prob_layer(z), dim=-1)
        background_prob_scale = torch.relu(self.prob_scale_layer(z)) + self.eps / self.anneal_floor
        background_prob_prior = background_prob * background_prob_scale
        topic_dist = self.topic_act(self.topic_layer(z))
        return background_prob_prior, topic_dist

# @config_enumerate
def parametrized_guide(predictor, data, args, batch_size=None, annealing_factor=1):
    # Use a conjugate guide for global variables.
    # (Param names and constraints reconstructed; they were lost from the post.)
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics + 1, args.num_words),
        constraint=constraints.positive)
    with poutine.scale(None, annealing_factor):
        pyro.sample("topic_weights", dist.Dirichlet(topic_weights_posterior))
        with pyro.plate("topics", args.num_topics + 1):
            pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", data.size(1)):
        background_prob_posterior, doc_topic_posterior = predictor(data)
        with poutine.scale(None, annealing_factor):
            pyro.sample("doc_background_prob",
                        dist.Beta(background_prob_posterior[:, 0],
                                  background_prob_posterior[:, 1]))
            pyro.sample("doc_topic", dist.Categorical(doc_topic_posterior),
                        infer={"enumerate": "parallel"})

At the moment I am using a linear model (the hidden layers in MLP are commented out), hoping that would stabilize the loss landscape while still being flexible enough for this simple model.

I think that because I need to infer the continuous background probability, I cannot use the strategy you outline above, right? Or if I can, I'm not sure how. Otherwise I agree that would probably be ideal!

You can still follow the strategy I suggested with the model above. All you need to do is delete the doc_topic site from your guide, remove or ignore doc_topic_posterior in your predictor, and mark doc_topic as enumerated in the model.
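Why the continuous doc_background_prob is not an obstacle: during training the guide supplies a sample of it (and of the global parameters), and conditional on that sample both discrete variables can be summed out exactly. A plain-Python sketch of that per-document term, assuming row 0 of topic_words is the background distribution (as `topic_ind = doc_topic * word_background_ind` implies, since doc_topic starts at 1) and that doc_background_prob is the per-word probability of drawing from that row; names and toy numbers are hypothetical.

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def doc_log_marginal(words, background_prob, topic_weights, topic_words):
    """log p(words | globals, background_prob) for one document.

    topic_words[0] is the background distribution; topic_words[1..K] are the
    topics (doc_topic is 1-based, as in the model). background_prob stands in
    for one guide sample of doc_background_prob.
    """
    background = topic_words[0]
    log_joint = []
    for k, weight in enumerate(topic_weights, start=1):
        topic = topic_words[k]
        # word_background_ind summed out per word:
        # p(w | k, b) = b * background[w] + (1 - b) * topic[w]
        log_lik = sum(
            math.log(background_prob * background[w]
                     + (1 - background_prob) * topic[w])
            for w in words
        )
        log_joint.append(math.log(weight) + log_lik)
    # doc_topic summed out per document
    return log_sum_exp(log_joint)


topic_weights = [0.5, 0.5]
topic_words = [
    [0.5, 0.25, 0.25],  # background
    [0.8, 0.1, 0.1],    # topic 1
    [0.1, 0.1, 0.8],    # topic 2
]
print(doc_log_marginal([0, 2], 0.5, topic_weights, topic_words))
```

The sum is a smooth function of background_prob, and Beta samples are reparameterizable in PyTorch, so gradients still reach the guide's parameters for doc_background_prob; only the doc_topic guide site goes away.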

The Gaussian mixture model tutorial I linked to above explains in detail how to use infer_discrete to sample from the posterior of doc_topic given samples from the guide for the other latent variables, including doc_background_prob.
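Given one posterior sample of the globals and of doc_background_prob, the exact conditional posterior of doc_topic is the per-topic joint probability, normalized over topics. A plain-Python sketch of what sampling from it amounts to, assuming topic_words[0] is the background distribution and doc_background_prob is the per-word probability of drawing from it; the function names and toy numbers are hypothetical, and in Pyro infer_discrete performs this step on the model itself.

```python
import math
import random


def doc_topic_posterior(words, background_prob, topic_weights, topic_words):
    """Exact p(doc_topic = k | words, globals, background_prob) for k = 1..K."""
    background = topic_words[0]
    log_joint = []
    for k, weight in enumerate(topic_weights, start=1):
        topic = topic_words[k]
        log_joint.append(math.log(weight) + sum(
            math.log(background_prob * background[w]
                     + (1 - background_prob) * topic[w])
            for w in words
        ))
    # Normalize in log space: subtract the max before exponentiating
    m = max(log_joint)
    unnorm = [math.exp(x - m) for x in log_joint]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def sample_doc_topic(words, background_prob, topic_weights, topic_words, rng):
    """Draw a 1-based doc_topic from its exact conditional posterior."""
    probs = doc_topic_posterior(words, background_prob, topic_weights, topic_words)
    return rng.choices(range(1, len(probs) + 1), weights=probs, k=1)[0]


topic_weights = [0.5, 0.5]
topic_words = [[0.5, 0.25, 0.25], [0.8, 0.1, 0.1], [0.1, 0.1, 0.8]]
print(doc_topic_posterior([0, 2], 0.5, topic_weights, topic_words))
print(sample_doc_topic([0, 2], 0.5, topic_weights, topic_words, random.Random(0)))
```

Taking the argmax of the posterior instead of sampling gives the MAP topic, which corresponds to calling infer_discrete with temperature=0.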

Ok, that sounds great. Then do I need an amortized guide for the background prob, since it's a local variable but not discrete?

You still need a guide site for doc_background_prob, and it is a local variable, but whether the guide is amortized or not is up to you.
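The choice is mostly about storage. A non-amortized guide for doc_background_prob needs its own pair of Beta concentrations for every document, so it grows with the corpus, while an amortized guide's size depends only on the network. A rough count, assuming (hypothetically) 300 million documents, a 50,000-word vocabulary, float32 storage, and a single linear layer from the word histogram to the two concentrations:

```python
def local_guide_params(num_docs, params_per_doc=2):
    """Non-amortized guide: Beta concentrations stored per document."""
    return num_docs * params_per_doc


def amortized_guide_params(num_words, num_outputs=2):
    """Amortized guide: one linear layer (weights + biases) on the histogram."""
    return num_words * num_outputs + num_outputs


BYTES_PER_FLOAT32 = 4
num_docs, num_words = 300_000_000, 50_000  # hypothetical sizes

local = local_guide_params(num_docs)
amortized = amortized_guide_params(num_words)
print(local, local * BYTES_PER_FLOAT32 / 1e9, "GB")  # 600000000 2.4 GB
print(amortized, amortized * BYTES_PER_FLOAT32 / 1e6, "MB")
```

With Adam, which keeps two running moment estimates per parameter, the local figure roughly triples.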


Got it! This was easy to implement, and it fixed the convergence problem in the topic-word distributions! Thank you so much!