Debugging Gradient NaNs

I’m working on implementing Chris Manning’s Hierarchical Bayesian Domain Adaptation (https://nlp.stanford.edu/pubs/hbda.pdf) in Pyro, and I’m having an issue with the gradient/loss coming back as NaN.

The model (the center figure at the top of the 3rd page of the PDF) has a common normal distribution and a domain-specific normal distribution whose mean is sampled from the common one; that domain-specific distribution then acts as the prior on the weights.
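
To spell out the hierarchy in my own notation (roughly; not the paper’s exact parameterization):

    mu_domain ~ Normal(mu_common, sigma_common)   # one draw of the mean per domain
    weights   ~ Normal(mu_domain, sigma_domain)   # per-domain prior on the classifier weights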

I’ve included my code below. If I don’t include the domain-specific prior and just use the common distribution as the prior on the weights, it runs fine. However, once I include the domain-specific prior, svi.step() returns NaN because the gradients are NaN. Additionally, if I don’t include any params in the common prior and set all the variables to requires_grad=False, it runs but returns garbage results. Any feedback on what might be wrong would be appreciated.

# Imports assumed for this snippet (old Variable-style PyTorch and the early Pyro dist API):
import torch
from torch.autograd import Variable
from torch.nn.functional import softplus

import pyro
import pyro.distributions as dist
from pyro.distributions import Normal, Categorical

# c (number of classes), p (number of features), N (dataset size) and the
# per-domain nn.Module dict `models` are defined elsewhere in my script.

def model(X, y, d):
    """
    X - the tensor of features, (minibatch_dim, feature_size)
    y - the tensor of labels, (minibatch_dim,)
    d - the domain (string)
    """
    # hyperparameters of the common (shared-across-domains) prior
    # NOTE: these zero sigmas turned out to be the bug (see the follow-up at the end)
    w_mu_common, b_mu_common = Variable(torch.zeros(c, p)), Variable(torch.zeros(c))
    w_sigma_common, b_sigma_common = Variable(torch.zeros(c, p)), Variable(torch.zeros(c))

    # domain-specific means, sampled from the common distribution
    w_mu_domain = pyro.sample(f'{d}_w_mu', dist.normal, w_mu_common, w_sigma_common)
    b_mu_domain = pyro.sample(f'{d}_b_mu', dist.normal, b_mu_common, b_sigma_common)

    # fixed scales for the domain-specific prior
    w_sigma_domain = Variable(torch.ones(c, p))
    b_sigma_domain = Variable(torch.ones(c))

    # domain-specific priors on the classifier weights and bias
    w_prior = Normal(w_mu_domain, w_sigma_domain)
    b_prior = Normal(b_mu_domain, b_sigma_domain)
    priors = {'linear.weight': w_prior, 'linear.bias': b_prior}

    # lift the domain's nn.Module so its parameters are drawn from the priors
    lifted_module = pyro.random_module(f'{d}_module', models[d], priors)
    lifted_reg_model = lifted_module()

    with pyro.iarange('map', N, subsample=X):
        prediction_mean = lifted_reg_model(X).squeeze()
        # categorical likelihood over the logits, conditioned on the observed labels
        pyro.sample('obs', Categorical(logits=prediction_mean), obs=y.squeeze())

def guide(X, y, d):
    # variational parameters for the common-level means and scales; softplus keeps
    # the scales positive (despite the *_log_sigma names, these are the scales themselves)
    w_mu_common = pyro.param('guide_m_w', Variable(torch.randn(c, p), requires_grad=True))
    w_log_sigma = softplus(
        pyro.param(
            'guide_s_w',
            Variable(-3. * torch.ones(c, p) + 0.05 * torch.randn(c, p), requires_grad=True)
        )
    )
    b_mu_common = pyro.param('guide_m_b', Variable(torch.randn(c), requires_grad=True))
    b_log_sigma = softplus(
        pyro.param(
            'guide_s_b',
            Variable(-3. * torch.ones(c) + 0.05 * torch.randn(c), requires_grad=True)
        )
    )
    # domain-specific level: the means are sampled at the same site names as in
    # the model; the scales are free variational parameters
    w_mu_domain = pyro.sample(f'{d}_w_mu', dist.normal, w_mu_common, w_log_sigma)
    w_log_sigma_domain = softplus(
        pyro.param(
            f'{d}_s_w_guide',
            Variable((-3.0 * torch.ones(c, p) + 0.05 * torch.randn(c, p)), requires_grad=True)
        )
    )
    b_mu_domain = pyro.sample(f'{d}_b_mu', dist.normal, b_mu_common, b_log_sigma)
    b_log_sigma_domain = softplus(
        pyro.param(
            f'{d}_s_b_guide',
            Variable((-3.0 * torch.ones(c) + 0.05 * torch.randn(c)), requires_grad=True)
        )
    )

    # pair the sampled domain means with the learned scales and lift the module
    w_dist_domain = Normal(w_mu_domain, w_log_sigma_domain)
    b_dist_domain = Normal(b_mu_domain, b_log_sigma_domain)
    dists = {'linear.weight': w_dist_domain, 'linear.bias': b_dist_domain}
    lifted_module = pyro.random_module(f'{d}_module', models[d], dists)

    return lifted_module()
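
For context, this is roughly how the training step that produces the NaN gets called; the optimizer settings, num_epochs, and the minibatch iterator below are just placeholders, not my exact setup:

from pyro.infer import SVI
from pyro.optim import Adam

svi = SVI(model, guide, Adam({'lr': 1e-3}), loss='ELBO')

for epoch in range(num_epochs):
    for X_batch, y_batch, domain in batches:  # placeholder minibatch iterator
        loss = svi.step(X_batch, y_batch, domain)  # this is the call that comes back as NaN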

Solved the problem: in the common prior I was accidentally setting sigma=torch.zeros() rather than torch.ones(), which gives the Normal a zero scale.
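
For anyone hitting the same thing, the fix in the model is just to give the common prior strictly positive scales:

    # corrected common prior scales (were zeros before)
    w_sigma_common, b_sigma_common = Variable(torch.ones(c, p)), Variable(torch.ones(c))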

Yes, in my experience, setting invalid parameters for distributions is the easiest way to get NaNs. I’ve run into this problem many times. :slight_smile:
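
One cheap sanity check for this (a minimal sketch; `scale` is a placeholder for whatever Variable you are about to pass as a Normal's sigma):

    # fail fast instead of silently producing NaN gradients later
    assert scale.data.min() > 0, 'Normal scale must be strictly positive'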
