I am new to Pyro. To try it out, I wanted to write a character-level RNN for text generation, following the Bayesian regression tutorial.
The idea is to write a PyTorch RNN, lift it using `pyro.random_module`, and try to infer all the parameters of the network using SVI.
I managed to write the model and the guide, but if I try to infer all the parameters I immediately run into `UserWarning: Encountered NAN log_prob_sum at site`, and the loss is already NaN at the first iteration.
Following a post I found here, I tried to combine probabilistic and non-probabilistic parameters. Interestingly, inference seems to work with probabilistic parameters for the encoder and the decoder and non-probabilistic weights for the GRU, but not the other way around.
Has anyone run into a similar issue? Any advice?
The network is the following:
class RNN(nn.Module):
    """Character-level GRU network: embedding -> GRU -> linear -> log-softmax.

    The recurrent state lives on the module as ``self.hidden`` and is carried
    from one ``forward`` call to the next; call ``init_hidden()`` to reset it.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        # init_hidden both stores and returns the zero state, so this
        # assignment yields a real tensor of shape (n_layers, 1, hidden_size).
        self.hidden = self.init_hidden()
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Run one sequence through the network.

        Args:
            input: 1-D tensor of integer indices (one per time step) accepted
                by ``nn.Embedding``.

        Returns:
            Tensor of shape ``(len(input), output_size)`` holding
            log-probabilities (``log_softmax`` over dim 1) for every step.
        """
        seq_len = len(input)
        embedded = self.encoder(input)
        # GRU expects (seq_len, batch, features); batch is fixed at 1 here.
        output, hidden = self.gru(embedded.view(seq_len, 1, -1), self.hidden)
        # Carry the state forward without its autograd history: keeping the
        # graph on the module would make later backward passes reach into a
        # previous call's graph, and a non-leaf tensor attribute makes the
        # module impossible to ``copy.deepcopy``.
        self.hidden = hidden.detach()
        output = self.decoder(output.view(seq_len, -1))
        return log_softmax(output, dim=1)

    def init_hidden(self):
        """Reset the recurrent state to zeros and return it."""
        self.hidden = torch.zeros(self.n_layers, 1, self.hidden_size)
        return self.hidden
Model and Guide:
# Base network whose parameters are lifted to random variables by model/guide.
rnn = RNN(input_size=n_characters, hidden_size=hidden_size, output_size=n_characters)
def centered_normal(*shape):
    """Return a Normal(0, 1) prior of the given shape.

    ``.independent(1)`` reinterprets the rightmost dimension as an event
    dimension, so ``log_prob`` is summed over it.
    """
    mean = torch.zeros(*shape)
    std = torch.ones(*shape)
    return Normal(mean, std).independent(1)
def model(input, target):
    """Generative model: Normal(0, 1) priors on the RNN weights, Categorical likelihood.

    Args:
        input: sequence passed to the lifted RNN's ``forward``.
        target: observed class indices for each position of ``input``
            (presumably the next characters — confirm against the data loader).
    """
    # Keys are parameter names of ``rnn``. Only the layer-0 GRU weights are
    # listed; extra GRU layers (n_layers > 1) would presumably stay ordinary,
    # non-random parameters.
    priors = {
        'encoder.weight': centered_normal(n_characters, hidden_size),
        'gru.weight_ih_l0': centered_normal(3 * hidden_size, hidden_size),
        'gru.weight_hh_l0': centered_normal(3 * hidden_size, hidden_size),
        'gru.bias_ih_l0': centered_normal(3 * hidden_size),
        'gru.bias_hh_l0': centered_normal(3 * hidden_size),
        'decoder.weight': centered_normal(n_characters, hidden_size),
        'decoder.bias': centered_normal(n_characters),
    }
    lifted_module = pyro.random_module("rnn", rnn, priors)
    lifted_reg_model = lifted_module()
    output = lifted_reg_model(input)
    # ``output`` is already log_softmax-normalised; using it as logits is
    # harmless because Categorical renormalises logits.
    pyro.sample("obs", Categorical(logits=output), obs=target)
def variable_normal(name, *shape):
    """Return a learnable Normal distribution for one weight tensor.

    Registers two params, ``<name>_loc`` and ``<name>_scale``, both randomly
    initialised with the given shape; the raw scale is passed through
    ``softplus`` to keep it positive. The rightmost dim is an event dim.
    """
    loc_key = name + "_loc"
    scale_key = name + "_scale"
    # Order matters: loc is registered (and its randn drawn) before scale.
    mean = pyro.param(loc_key, torch.randn(*shape))
    raw_std = pyro.param(scale_key, torch.randn(*shape))
    return Normal(mean, softplus(raw_std)).independent(1)
def guide(input, target):
    """Variational guide: a learnable Normal over each lifted RNN weight.

    ``input`` and ``target`` are unused here but must match the model's
    signature because SVI passes the same arguments to both.

    Returns:
        The RNN module with weights sampled from the variational distributions.
    """
    # Keys must match those of ``priors`` in the model so every sampled
    # weight site has a guide counterpart.
    dists = {
        'encoder.weight': variable_normal("enc_w", n_characters, hidden_size),
        'gru.weight_ih_l0': variable_normal("gru_w_ih", 3 * hidden_size, hidden_size),
        'gru.weight_hh_l0': variable_normal("gru_w_hh", 3 * hidden_size, hidden_size),
        'gru.bias_ih_l0': variable_normal("gru_b_ih", 3 * hidden_size),
        'gru.bias_hh_l0': variable_normal("gru_b_hh", 3 * hidden_size),
        'decoder.weight': variable_normal("dec_w", n_characters, hidden_size),
        'decoder.bias': variable_normal("dec_b", n_characters),
    }
    lifted_module = pyro.random_module("rnn", rnn, dists)
    return lifted_module()
Inference loop:
optimizer = Adam({"lr": 0.005})
inference = SVI(model, guide, optimizer, loss=Trace_ELBO())

for epoch in range(1, n_epochs + 1):
    for i, (input, target) in enumerate(data):
        # One SVI step: runs guide and model, updates params, returns the
        # loss estimate; it is divided by the sequence length for reporting.
        loss = inference.step(input, target) / len(input)
        if i % print_every:
            continue
        print('[epoch {}, iteration {}, loss = {}]'.format(epoch, i, loss))
        print(evaluate_samples(), '\n')