VAE sampling on a categorical distribution: workaround?

I have a VAE that’s used to do seq2seq on people’s names. The encoder takes in a name and the decoder outputs the name, but the ELBO plateaus really early. I’m assuming this is because gradients can’t be calculated through the samples, since the distribution is discrete (a categorical over alphabetical characters). Is there a way around this? Perhaps replacing it with a continuous distribution that would work in this case?
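
For concreteness, here is a minimal sketch (separate from the code later in the thread) of the issue being described: a plain Categorical has no rsample(), so no pathwise gradient reaches the logits, while the Gumbel-Softmax relaxation (RelaxedOneHotCategorical) is reparameterized and does pass gradients.

import torch
from torch.distributions import Categorical, RelaxedOneHotCategorical

logits = torch.randn(5, requires_grad=True)

# A hard Categorical sample is an integer index with no rsample(); nothing upstream
# of the sample receives a pathwise gradient.
hard = Categorical(logits=logits).sample()
print(hard.requires_grad)  # False

# The relaxed distribution is reparameterized: rsample() returns a soft one-hot
# vector that is differentiable with respect to the logits.
soft = RelaxedOneHotCategorical(temperature=torch.tensor(0.5), logits=logits).rsample()
soft.sum().backward()
print(logits.grad is not None)  # True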

Hi, can you please provide a code snippet (ideally a runnable one) that demonstrates your model and its failure mode more concretely? It’s not clear from your post what the issue is, exactly - why do you need gradients through discrete variables, which I would have assumed are observed at training time, in a seq2seq VAE?

import os
import pyro
import torch
import torch.nn.functional as F
from const import DEVICE
from neural_net.seq2seq import Encoder, Decoder, PreTrainedDecoder
from util.convert import index_to_letter, FORMAT_DIST, SOS, EOS, PAD, LETTERS_COUNT, strings_to_tensor, \
    strings_to_index_tensor, strings_to_probs, printable_to_index, to_rnn_tensor, PRINTABLE_COUNT, pad_string, DROP_OUT
from util.nn_utils import mask_char
MAX_NAME_LENGTH = 6
MAX_OUTPUT_LENGTH = 10
FIRST_NAME_ADD = "first_name"
LAST_NAME_ADD = "last_name"
MIDDLE_NAME_ADD = "middle_name"
def generate_name(lstm: Decoder, address: str, mask: torch.Tensor, hidd_cell_states: tuple = None,
                  max_name_length: int = MAX_NAME_LENGTH, batch_size: int = 1,
                  encoder_outputs: list = None, observed: list = None):
    """
    Takes an LSTM decoder and a hidden state vector if one is passed in. If no hidden state vector
    is passed in, assume that this is the last name (Frank specified running on LN first).
    lstm: Decoder associated with the name being generated
    address: The address that correlates the pyro sample sites with the latent variables
    hidd_cell_states: Previous LSTM hidden state, or an empty hidden state
    max_name_length: The max name length allowed
    Full names are pre-padded to mitigate vanishing gradients
    Observed inputs are post-padded so the model generates names straight away
    """
    # If supervised, convert the supervised strings into an index vector and observe characters at each timestep
    if observed is None:
        observed = [None] * max_name_length
    else:
        observed = list(map(lambda name: name + EOS, observed))
        observed = strings_to_index_tensor(observed, max_name_length)
    # If no hidden state is provided, initialize it with all 0s
    if hidd_cell_states is None:
        hidd_cell_states = lstm.init_hidden(batch_size=batch_size)
    input_tensor = strings_to_tensor([SOS] * batch_size, 1)
    names = [''] * batch_size
    for index in range(max_name_length):
        char_dist, hidd_cell_states = lstm.forward(input_tensor, hidd_cell_states, encoder_outputs, mask)
        # Correlate index in name generation as the address
        # Temperature must be a tensor for the relaxed distribution's log_prob to work.
        # NOTE: observed[index] is an index tensor, while this distribution's support is one-hot vectors.
        sampled_indexes = pyro.sample(
            f"{address}_{index}",
            pyro.distributions.RelaxedOneHotCategoricalStraightThrough(torch.tensor(1.0), logits=char_dist),
            obs=observed[index])
        # The sample is a one-hot over characters; recover the character at each batch index
        chars_at_indexes = []
        for i in range(len(sampled_indexes[0])):
            for j in range(len(sampled_indexes[0][i])):
                if sampled_indexes[0][i][j] == 1:
                    chars_at_indexes.append(index_to_letter(j))
        # Add sampled characters to names     
        for i, char in enumerate(chars_at_indexes):
            names[i] += char
        # input_tensor = strings_to_tensor(chars_at_indexes, 1)
        # TRYING: not sampling at each step and instead passing the argmax as the next input
        input_tensor = torch.zeros(input_tensor.shape).to(DEVICE)
        max_indices = torch.argmax(char_dist, dim=2)
        for i_batch, max_index in enumerate(max_indices[0]):
            input_tensor[0, i_batch, max_index] = 1.
    # Discard everything after EOS character
    # SemiSupervised Goal: Encoder takes prepadded strings and decoder takes postpadded strings 
    names = list(map(lambda name: name[:name.find(EOS)] if name.find(EOS) > -1 else name, names))
    return hidd_cell_states, names
class NameGenerator():
    """
    Generates names using a separate LSTM for first name, last name, middle name and a neural net
    using ELBO to parameterize NN for format classification.
    input_size: Should be the number of letters to allow
    hidden_size: Size of the hidden dimension in LSTM
    num_layers: Number of hidden layers in LSTM
    hidden_sz: Hidden layer size for LSTM RNN
    peak_prob: The max expected probability
    drop_out: The percent decoder drops the previous input for forward
    """
    def __init__(self, num_layers: int = 8, hidden_sz: int = 524, peak_prob: float = 0.9):
        super().__init__()
        # Model neural nets instantiation
        self.model_fn_lstm = Decoder(LETTERS_COUNT, hidden_sz, LETTERS_COUNT, num_layers=num_layers,
                                     max_length=MAX_NAME_LENGTH)
        # Guide neural nets instantiation
        self.guide_fn_lstm = Decoder(LETTERS_COUNT, hidden_sz, LETTERS_COUNT, num_layers=num_layers,
                                     max_length=MAX_NAME_LENGTH)
        # Instantiate encoder
        self.encoder_lstm = Encoder(PRINTABLE_COUNT, hidden_sz, num_layers=num_layers)
        # Hyperparameters
        self.peak_prob = peak_prob
        self.num_layers = num_layers
        self.hidden_sz = hidden_sz
    def model(self, x: list, z: dict = None):
        """
        Model for generating names representing p(x,z)
        x: Training data (name string)
        z: Optionally supervised latent values (dictionary of name/format values)
        """
        padded_x = list(map(lambda s: pad_string(s, MAX_NAME_LENGTH, pre_pad=False), x))
        index_tensors = strings_to_index_tensor(padded_x, MAX_NAME_LENGTH, printable_to_index)
        encoder_outputs = torch.zeros(MAX_NAME_LENGTH, len(x), self.num_layers * self.hidden_sz * 2).to(DEVICE)
        if z is None:
            # Unsupervised: Do ELBO Maximization
            pyro.module("model_fn_lstm", self.model_fn_lstm)
            z = {'name_format': None, FIRST_NAME_ADD: None, MIDDLE_NAME_ADD: None, LAST_NAME_ADD: None}
            initial_hidden = None
            fn_lstm = self.model_fn_lstm
        else:
            # Supervised: Do MLE
            pyro.module("guide_fn_lstm", self.guide_fn_lstm)
            pyro.module("encoder_lstm", self.encoder_lstm)
            initial_hidden = self.encoder_lstm.init_hidden(batch_size=len(x))
            fn_lstm = self.guide_fn_lstm
            x_formatted = to_rnn_tensor(index_tensors, letter_count=PRINTABLE_COUNT)
            for i in range(x_formatted.shape[0]):
                _, initial_hidden = self.encoder_lstm.forward(x_formatted[i].unsqueeze(0), initial_hidden)
                encoder_outputs[i] = torch.cat((initial_hidden[0], initial_hidden[1]), dim=2).squeeze(0)
        batch_sz = len(x)
        with pyro.plate("data", batch_sz):
            mask = mask_char(padded_x, MAX_NAME_LENGTH)
            _, first_names = generate_name(fn_lstm, FIRST_NAME_ADD, mask, initial_hidden,
                                           batch_size=batch_sz, encoder_outputs=encoder_outputs,
                                           observed=z[FIRST_NAME_ADD])
            full_names = list(map(lambda name: pad_string(name, MAX_NAME_LENGTH, pre_pad=False), first_names))
            probs = strings_to_probs(full_names, MAX_NAME_LENGTH, true_index_prob=self.peak_prob,
                                     index_function=printable_to_index)
            with pyro.plate("likelihood", MAX_NAME_LENGTH):
                pyro.sample("output", pyro.distributions.Categorical(probs), obs=index_tensors)
            return full_names
    def guide(self, x: list, z: dict = None):
        """
        Guide for approximation of the posterior q(z|x)
        x: Training data (name string)
        z: Optionally supervised latent values (dictionary of name/format values)
        """
        if z is not None: return
        pyro.module("guide_fn_lstm", self.guide_fn_lstm)
        pyro.module("encoder_lstm", self.encoder_lstm)
        batch_sz = len(x)
        with pyro.plate("data", batch_sz):
            initial_hidden = self.encoder_lstm.init_hidden(batch_size=batch_sz)
            x_padded = list(map(lambda s: pad_string(s, MAX_NAME_LENGTH, pre_pad=False), x))
            index_tensors = strings_to_index_tensor(x_padded, MAX_NAME_LENGTH, printable_to_index)
            x_formatted = to_rnn_tensor(index_tensors, letter_count=PRINTABLE_COUNT)
            encoder_outputs = torch.zeros(MAX_NAME_LENGTH, batch_sz, self.num_layers * self.hidden_sz * 2).to(DEVICE)
            for i in range(x_formatted.shape[0]):
                _, initial_hidden = self.encoder_lstm.forward(x_formatted[i].unsqueeze(0), initial_hidden)
                encoder_outputs[i] = torch.cat((initial_hidden[0], initial_hidden[1]), dim=2).squeeze(0)
            mask = mask_char(x_padded, MAX_NAME_LENGTH)
            _, first_names = generate_name(self.guide_fn_lstm, FIRST_NAME_ADD, mask,
                                           initial_hidden, batch_size=batch_sz,
                                           encoder_outputs=encoder_outputs)
            return first_names
    def load_pretrained_weights(self, folder="generators/nn_model/pretrain"):
        filepath_f = os.path.join(folder, "first_checkpt.pth.tar")
        filepath_m = os.path.join(folder, "middle_checkpt.pth.tar")
        filepath_l = os.path.join(folder, "last_checkpt.pth.tar")
        if not os.path.exists(filepath_f) or not os.path.exists(filepath_m) or not os.path.exists(filepath_l):
            raise Exception(f"Pretrained weights do not exist")
        self.model_fn_lstm.load_state_dict(torch.load(filepath_f, map_location=DEVICE)['weights'])
    def load_checkpoint(self, folder="nn_model", filename="checkpoint.pth.tar"):
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            raise Exception(f"No model in path {folder}")
        save_content = torch.load(filepath, map_location=DEVICE)
        self.model_fn_lstm.load_state_dict(save_content['model_fn_lstm'])
        self.guide_fn_lstm.load_state_dict(save_content['guide_fn_lstm'])
        self.encoder_lstm.load_state_dict(save_content['encoder_lstm'])
    def save_checkpoint(self, folder="nn_model", filename="checkpoint.pth.tar"):
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            os.mkdir(folder)
        save_content = {
            'model_fn_lstm': self.model_fn_lstm.state_dict(),
            'guide_fn_lstm': self.guide_fn_lstm.state_dict(),
            'encoder_lstm': self.encoder_lstm.state_dict()
        }
        torch.save(save_content, filepath)
if __name__ == "__main__":
    import sys
    from util.config import load_json
    """
    Usage: python ss_train.py <Config File Path>
    """
    config = load_json(sys.argv[1])
    HIDDEN_SIZE = config.get('hidden_size', 16)
    SESSION_NAME = config['session_name']
    ss_vae = NameGenerator(hidden_sz=HIDDEN_SIZE, num_layers=1)
    ss_vae.load_checkpoint(folder="generators/nn_model", filename=f"{SESSION_NAME}.pth.tar")
    NAME = ['jason', 'mike', 'dylan']
    print(f"Reconstructing '{NAME}'...")
    for _ in range(5):
        print(ss_vae.guide(NAME))

Here’s the code. The RelaxedOneHotCategoricalStraightThrough site used to be a plain Categorical distribution; I’m trying to use the Gumbel-Softmax trick instead, but I don’t know how to implement it in Pyro.
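
For what it’s worth, here is a hedged sketch (illustrative names only, not the code above) of how a Gumbel-Softmax / relaxed-categorical sample site is usually written in Pyro. Two caveats: the temperature needs to be a tensor, and any observed value at such a site has to lie in the distribution’s support, i.e. a one-hot style vector over characters rather than an integer index.

import torch
import pyro
import pyro.distributions as dist

def sample_char(address, logits, observed=None, temperature=1.0):
    # The straight-through relaxed categorical is reparameterized, so SVI uses
    # pathwise gradients at this site instead of the score-function estimator.
    d = dist.RelaxedOneHotCategoricalStraightThrough(
        temperature=torch.tensor(temperature), logits=logits)
    # If observed is given, it should be a one-hot over characters (e.g. via
    # torch.nn.functional.one_hot), not an index tensor like the one
    # strings_to_index_tensor appears to produce.
    return pyro.sample(address, d, obs=observed)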

Sorry, I’m still a little confused about your approach and why there are latent variables in your model at all at training time. Can you explain what you’re trying to do here or point to a paper you’re trying to reproduce? It looks like you want your model to fill in parts of a name that are left out at test time, but is any of your training data actually missing? I ask because you may not have to deal with the immediate issue you’re worried about at all.

This PyTorch tutorial looks like a nice introduction to training an RNN to generate names that you could easily adapt to name reconstruction by passing in partially missing names at test time - what’s different about your problem or your dataset?

I’m hoping the VAE will learn the distribution of characters at each index, conditioned on a hidden state that I’m encoding, for name denoising. As a precursor to that I’m trying a simple seq2seq, which it’s failing to do. I’m using a structured VAE; currently the MLE loss with full supervision is doing fine, unsurprisingly, but the ELBO is doing terribly and is actually making the weights worse. I’m assuming that’s because I’m sampling from a categorical, which can’t be backpropagated through, which is why I’m trying to use Gumbel-Softmax instead. For now I’m trying to get seq2seq to work, then eventually name denoising. The latent variables are the true character at each index.
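
A hedged side note: Pyro’s ELBO implementations do produce gradients for non-reparameterized sites via the score-function (REINFORCE) estimator, so the more likely culprit is gradient variance rather than there being no gradient at all. Besides the Gumbel-Softmax relaxation, another standard option when the discrete space is small is to marginalize the latent exactly with TraceEnum_ELBO. A toy sketch (illustrative model, not the one in this thread) is below; exact enumeration over a long autoregressive chain of characters through an LSTM state generally isn’t tractable, so treat this as a pointer rather than a drop-in fix.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

def model(data):
    # Uniform prior over 26 "characters", then a toy likelihood conditioned on z
    z = pyro.sample("z", dist.Categorical(logits=torch.zeros(26)))
    loc = z.float()  # stand-in for a decoder applied to z
    pyro.sample("x", dist.Normal(loc, 1.0), obs=data)

@config_enumerate  # mark every discrete sample site in the guide for parallel enumeration
def guide(data):
    q_logits = pyro.param("q_logits", torch.zeros(26))
    pyro.sample("z", dist.Categorical(logits=q_logits))

svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=TraceEnum_ELBO(max_plate_nesting=0))
svi.step(torch.tensor(3.0))  # the 26 values of z are summed out exactly, no sampling needed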

But why do you need to sample any characters at training time at all? If you have a training dataset of complete names, wouldn’t it make more sense to do one of the following?

  1. Train a sequence model on your data as in that PyTorch tutorial on name generation, then at test time pass in a name with missing portions and sample the missing parts from the model.
  2. Inject some sort of plausible noise into your training data, perhaps by randomly adding/editing characters, permuting adjacent characters, or replacing some characters with a special “missing” value, and train a sequence-to-sequence RNN to reconstruct the noise-free string from the noisy string (a rough sketch of this is at the end of this post).

I’m still not sure I understand your model, which seems like an attempt at a hybrid between 1 and 2, but both of these standard approaches seem simpler and more reliable regardless and should perform well on your problem.
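
For option 2, here is a small sketch of what the noise injection could look like; the operations and rates are just illustrative assumptions.

import random
import string

def corrupt(name: str, p: float = 0.15, missing: str = "_") -> str:
    """Randomly drop, substitute, or blank out characters of a clean name."""
    out = []
    for ch in name:
        r = random.random()
        if r < p / 3:
            continue                                           # drop the character
        elif r < 2 * p / 3:
            out.append(random.choice(string.ascii_lowercase))  # substitute a random letter
        elif r < p:
            out.append(missing)                                # mark as missing
        else:
            out.append(ch)                                     # keep it unchanged
    return "".join(out)

# Training pairs are then (corrupt(name), name), and the seq2seq model learns to
# reconstruct the clean name on the right from the noisy name on the left.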

You’re correct, but I’d like to be able to sample names with probabilities, which this setup allows. The setup is essentially using an LSTM to encode a hidden state, then having the decoder generate a categorical distribution to sample from at each character index. I also just want to know whether it’s possible to build a better model with a VAE than with a pure RNN. I had also thought of trying what you’re suggesting, which is essentially a denoising autoencoder, but I’d still like to see whether a VAE can do better using inference. I really enjoy this back and forth we’re having; just out of curiosity, why do you think the RNN method is “more reliable”? Do you mean empirically, as seen in its applications, or is there something in the underlying architecture that makes it more reliable?

You’re correct, but I’d like to be able to sample names with probabilities, which this setup allows. The setup is essentially using an LSTM to encode a hidden state, then having the decoder generate a categorical distribution to sample from at each character index.

I’m afraid I don’t see how this is different from the standard RNN approaches I suggested or why you would expect your current model to do better or even work at all as written. You might want to write down in detail the math implemented in your model and guide, since it doesn’t seem to correspond to any standard VAE setup, and to look at this paper instead for a simple, canonical approach to sequence-to-sequence VAEs, in which Gaussian noise is injected into the continuous hidden state between encoder and decoder.
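
To make that canonical setup concrete, here is a hedged sketch of it in Pyro: the only stochastic latent is a continuous Gaussian code between encoder and decoder, the characters themselves are observed, and everything is reparameterizable. The encoder/decoder interfaces here are assumptions, not the classes from the code above.

import torch
import pyro
import pyro.distributions as dist

class SeqVAE:
    """Sketch only: `encoder` maps a (batch, seq_len) index tensor to (loc, scale) of
    q(z|x); `decoder` maps z to per-character logits of shape (batch, seq_len, n_chars)."""
    def __init__(self, encoder, decoder, latent_dim):
        self.encoder, self.decoder, self.latent_dim = encoder, decoder, latent_dim

    def model(self, x_indices):
        pyro.module("decoder", self.decoder)
        batch = x_indices.size(0)
        with pyro.plate("data", batch):
            # Standard normal prior on the continuous code between encoder and decoder
            z = pyro.sample("z", dist.Normal(torch.zeros(batch, self.latent_dim),
                                             torch.ones(batch, self.latent_dim)).to_event(1))
            logits = self.decoder(z)
            # Characters are observed, so no gradient has to flow through a discrete
            # sample; only the continuous z is reparameterized.
            pyro.sample("chars", dist.Categorical(logits=logits).to_event(1), obs=x_indices)

    def guide(self, x_indices):
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x_indices.size(0)):
            loc, scale = self.encoder(x_indices)  # amortized q(z | x)
            pyro.sample("z", dist.Normal(loc, scale).to_event(1))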

I also just want to know whether it’s possible to build a better model with a VAE than with a pure RNN […] why do you think the RNN method is “more reliable”

If you’re doing this as a learning exercise, I would encourage you either to start from one of those PyTorch tutorials, which already work correctly and reliably for very similar problems, and make small modifications whose correctness you can verify mathematically and experimentally, or to attempt to reproduce a simple, highly cited paper such as the sequence-to-sequence VAE above. That way you won’t get stuck on the frustrating technical details necessary to make an entirely new and different model work, like estimating gradients through discrete samples.

If you’re applying this to a real-world problem, starting with something simple that already works is even more important.