Bayesian RNN (NaN loss issue)

Hi,

I am new to Pyro. To try it out, I wanted to write a character-level RNN for text generation, following the tutorial on Bayesian regression.
The idea is to write a PyTorch RNN, lift it with pyro.random_module, and infer all of the network's parameters with SVI.

I managed to write the model and the guide, but as soon as I try to infer all the parameters I run into a "UserWarning: Encountered NAN log_prob_sum at site" and the loss is already NaN at the first iteration.

Following a post I found here, I tried to combine probabilistic and non-probabilistic parameters (a sketch of what I mean is below). Interestingly, inference seems to work with probabilistic parameters for the encoder and decoder and non-probabilistic weights for the GRU, but not the other way around.
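For context, by "combining" I mean passing priors for only some of the submodules to pyro.random_module, so that the remaining weights stay ordinary nn.Parameters. Roughly something like this (a sketch, not my exact code; it reuses the centered_normal helper and the rnn instance defined further down):

def partial_model(input, target):
    # Priors only for the encoder and decoder; the GRU keeps its
    # ordinary (non-probabilistic) nn.Parameter weights.
    priors = {
        'encoder.weight': centered_normal(n_characters, hidden_size),
        'decoder.weight': centered_normal(n_characters, hidden_size),
        'decoder.bias': centered_normal(n_characters)
    }
    lifted_module = pyro.random_module("rnn", rnn, priors)
    lifted_rnn = lifted_module()
    lifted_rnn.init_hidden()
    output = lifted_rnn(input)
    pyro.sample("obs", Categorical(logits=output), obs=target)

The guide then only provides variational distributions for the same subset of sites.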

Has anyone run into a similar issue? Any advice?

The network is the following:

import torch
import torch.nn as nn
from torch.nn.functional import log_softmax


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden = self.init_hidden()

        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        input = self.encoder(input)
        output, self.hidden = self.gru(
            input.view(len(input), 1, -1), self.hidden)
        output = self.decoder(output.view(len(input), -1))
        return log_softmax(output, dim=1)

    def init_hidden(self):
        # Reset the hidden state (returned as well, so it can be assigned directly)
        self.hidden = torch.zeros(self.n_layers, 1, self.hidden_size)
        return self.hidden

Model and Guide:

import pyro
from pyro.distributions import Normal, Categorical
from torch.nn.functional import softplus

rnn = RNN(n_characters, hidden_size, n_characters)

def centered_normal(*shape):
    return Normal(torch.zeros(*shape), torch.ones(*shape)).independent(1)

def model(input, target):
    priors = {
        'encoder.weight': centered_normal(n_characters, hidden_size),
        'gru.weight_ih_l0': centered_normal(3 * hidden_size, hidden_size),
        'gru.weight_hh_l0': centered_normal(3 * hidden_size, hidden_size),
        'gru.bias_ih_l0': centered_normal(3 * hidden_size),
        'gru.bias_hh_l0': centered_normal(3 * hidden_size),
        'decoder.weight': centered_normal(n_characters, hidden_size),
        'decoder.bias': centered_normal(n_characters)
    }
    lifted_module = pyro.random_module("rnn", rnn, priors)
    lifted_reg_model = lifted_module()
    lifted_reg_model.init_hidden()
    output = lifted_reg_model(input)
    pyro.sample("obs", Categorical(logits=output), obs=target)


def variable_normal(name, *shape):
    loc = pyro.param(name + "_loc", torch.randn(*shape))
    scale = softplus(pyro.param(name + "_scale", torch.randn(*shape)))
    return Normal(loc, scale).independent(1)

def guide(input, target):
    dists = {
        'encoder.weight': variable_normal("enc_w", n_characters, hidden_size),
        'gru.weight_ih_l0': variable_normal("gru_w_ih", 3 * hidden_size, hidden_size),
        'gru.weight_hh_l0': variable_normal("gru_w_hh", 3 * hidden_size, hidden_size),
        'gru.bias_ih_l0': variable_normal("gru_b_ih", 3 * hidden_size),
        'gru.bias_hh_l0': variable_normal("gru_b_hh", 3 * hidden_size),
        'decoder.weight': variable_normal("dec_w", n_characters, hidden_size),
        'decoder.bias': variable_normal("dec_b", n_characters)
    }
    lifted_module = pyro.random_module("rnn", rnn, dists)
    return lifted_module()

Inference loop:

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

inference = SVI(model, guide, Adam({"lr": 0.005}), loss=Trace_ELBO())
for epoch in range(1, n_epochs + 1):
    for i, (input, target) in enumerate(data):
        loss = inference.step(input, target)
        loss = loss / len(input)

        if i % print_every == 0:
            print('[epoch {}, iteration {}, loss = {}]'.format(epoch, i, loss))
            print(evaluate_samples(), '\n')

I don’t see any obvious problems with your code (though I may have missed something), but getting variational inference (or most other approximate inference algorithms) to work with nontrivial Bayesian neural net models and data sets is generally pretty difficult.

Some generic tips: you could initialize your variational parameters to look more like standard neural-net initializations, identify the sites that are producing NaNs and add constraints or transformations to those variables/parameters, try Pyro's ClippedAdam optimizer, or try Pyro's gradient-based MCMC samplers, HMC and NUTS.
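For example, here is a rough, untested sketch of what those tips could look like in code: small initial values for the variational parameters, a constrained pyro.param to keep the scale positive instead of a softplus, and ClippedAdam to bound the gradient norm (the 0.01/0.05 initial values and clip_norm are just illustrative choices):

from torch.distributions import constraints
from pyro.optim import ClippedAdam

def variable_normal(name, *shape):
    # Location initialized like a small random NN init; scale kept positive
    # by the constraint rather than a softplus transform.
    loc = pyro.param(name + "_loc", 0.01 * torch.randn(*shape))
    scale = pyro.param(name + "_scale", 0.05 * torch.ones(*shape),
                       constraint=constraints.positive)
    return Normal(loc, scale).independent(1)

inference = SVI(model, guide, ClippedAdam({"lr": 0.005, "clip_norm": 10.0}),
                loss=Trace_ELBO())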

Thanks for your answer.
I tried your suggestions: if I initialize the variational parameters closer to 0, the NaN issue goes away.
However, the loss does not decrease, even after 20,000 iterations…
Does anyone know why?

I will try HMC and NUTS, but I wonder if it is possible to make it work with SVI.
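For reference, this is roughly the MCMC setup I am planning to try (untested; the exact import path and the sample/warmup counts will depend on the Pyro version):

from pyro.infer.mcmc import MCMC, NUTS

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=200)
mcmc.run(input, target)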

Here is my new guide:

def variable_normal(name, *shape):
    l = torch.empty(*shape, requires_grad=True)
    s = torch.empty(*shape, requires_grad=True)
    torch.nn.init.normal_(l, std=0.01)
    torch.nn.init.normal_(s, std=0.01)
    loc = pyro.param(name+"_loc", l)
    scale = nn.functional.softplus(pyro.param(name+"_scale", s))
    return Normal(loc, scale)


def guide(input, target):
    dists = {
        'encoder.weight': variable_normal("enc_w", n_characters, hidden_size),
        'gru.weight_ih_l0': variable_normal("gru_w_ih", 3 * hidden_size, hidden_size),
        'gru.weight_hh_l0': variable_normal("gru_w_hh", 3 * hidden_size, hidden_size),
        'gru.bias_ih_l0': variable_normal("gru_b_ih", 3 * hidden_size),
        'gru.bias_hh_l0': variable_normal("gru_b_hh", 3 * hidden_size),
        'decoder.weight': variable_normal("dec_w", n_characters, hidden_size),
        'decoder.bias': variable_normal("dec_b", n_characters)
    }
    lifted_module = pyro.random_module("rnn", rnn, dists)
    return lifted_module()