I’m working on implementing Chris Manning’s Hierarchical Bayesian Domain Adaptation (https://nlp.stanford.edu/pubs/hbda.pdf) in Pyro, and I’m running into an issue with the gradient/loss returning nan.
The model (center image at the top of the 3rd page in the PDF) has a common normal distribution and then a domain-specific normal distribution (whose mean is sampled from the common distribution); that domain-specific distribution acts as the prior on the weights.
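Concretely, the structure I’m trying to express is roughly the following (my reading of the paper; the sigma values here are placeholders, not the ones from the paper):

# mu_common      ~ Normal(0, sigma_common)           shared across all domains
# mu_domain[d]   ~ Normal(mu_common, sigma_domain)   one per domain d
# weights[d]     ~ Normal(mu_domain[d], sigma_w)     prior on the classifier weights for domain d
# y | X          ~ Categorical(logits = linear(X; weights[d], bias[d]))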
I’ve included my code below. If I don’t include the domain-specific prior and just use the common distribution as the prior on the weights, it runs fine (roughly the variant sketched after the model function below). However, once I include the domain-specific prior, svi.step() returns nan, because the gradients are nan. Additionally, if I don’t include any params in the common prior and set all the variables to requires_grad=False, it runs but returns garbage results. Any feedback on what might be wrong would be appreciated.
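The snippets reference a few names defined elsewhere in my code. Roughly, the setup looks like this (the sizes and domain names are illustrative, and the imports reflect the older Pyro/PyTorch API I’m on, so exact paths may differ for other versions):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import pyro
import pyro.distributions as dist
from pyro.distributions import Normal, Categorical

softplus = F.softplus        # used in the guide to keep scales positive
c, p, N = 5, 100, 10000      # num classes, feature dim, dataset size (illustrative)

class RegressionModel(nn.Module):
    # one small linear classifier per domain; parameter names are
    # 'linear.weight' / 'linear.bias', matching the priors dicts below
    def __init__(self, p, c):
        super(RegressionModel, self).__init__()
        self.linear = nn.Linear(p, c)
    def forward(self, x):
        return self.linear(x)

models = {d: RegressionModel(p, c) for d in ('domain_a', 'domain_b')}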
def model(X, y, d):
    """
    X - the tensor of features, (minibatch_dim, feature_size)
    y - the tensor of labels, (minibatch_dim,)
    d - the domain (string)
    """
    # Common (cross-domain) prior parameters
    w_mu_common, b_mu_common = Variable(torch.zeros(c, p)), Variable(torch.zeros(c))
    w_sigma_common, b_sigma_common = Variable(torch.zeros(c, p)), Variable(torch.zeros(c))
    # Domain-specific means, sampled from the common distribution
    w_mu_domain = pyro.sample(f'{d}_w_mu', dist.normal, w_mu_common, w_sigma_common)
    b_mu_domain = pyro.sample(f'{d}_b_mu', dist.normal, b_mu_common, b_sigma_common)
    w_sigma_domain = Variable(torch.ones(c, p))
    b_sigma_domain = Variable(torch.ones(c))
    # Domain-specific priors on the linear layer's weights and biases
    w_prior = Normal(w_mu_domain, w_sigma_domain)
    b_prior = Normal(b_mu_domain, b_sigma_domain)
    priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
    lifted_module = pyro.random_module(f'{d}_module', models[d], priors)
    lifted_reg_model = lifted_module()
    with pyro.iarange('map', N, subsample=X):
        prediction_mean = lifted_reg_model(X).squeeze()
        pyro.sample('obs', Categorical(logits=prediction_mean), obs=y.squeeze())
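For comparison, the variant that runs without any nan just puts the common distribution directly on the weights, roughly like this (simplified from memory):

def model_common_only(X, y, d):
    # same as model() above, but without the domain-specific level
    w_prior = Normal(Variable(torch.zeros(c, p)), Variable(torch.ones(c, p)))
    b_prior = Normal(Variable(torch.zeros(c)), Variable(torch.ones(c)))
    priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
    lifted_reg_model = pyro.random_module(f'{d}_module', models[d], priors)()
    with pyro.iarange('map', N, subsample=X):
        prediction_mean = lifted_reg_model(X).squeeze()
        pyro.sample('obs', Categorical(logits=prediction_mean), obs=y.squeeze())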
def guide(X, y, d):
    # Variational parameters for the common distribution
    w_mu_common = pyro.param('guide_m_w', Variable(torch.randn(c, p), requires_grad=True))
    w_log_sigma = softplus(
        pyro.param(
            'guide_s_w',
            Variable(-3. * torch.ones(c, p) + 0.05 * torch.randn(c, p), requires_grad=True)
        )
    )
    b_mu_common = pyro.param('guide_m_b', Variable(torch.randn(c), requires_grad=True))
    b_log_sigma = softplus(
        pyro.param(
            'guide_s_b',
            Variable(-3. * torch.ones(c) + 0.05 * torch.randn(c), requires_grad=True)
        )
    )
    # Sample the domain-specific means from the common variational distribution
    w_mu_domain = pyro.sample(f'{d}_w_mu', dist.normal, w_mu_common, w_log_sigma)
    w_log_sigma_domain = softplus(
        pyro.param(
            f'{d}_s_w_guide',
            Variable((-3.0 * torch.ones(c, p) + 0.05 * torch.randn(c, p)), requires_grad=True)
        )
    )
    b_mu_domain = pyro.sample(f'{d}_b_mu', dist.normal, b_mu_common, b_log_sigma)
    b_log_sigma_domain = softplus(
        pyro.param(
            f'{d}_s_b_guide',
            Variable((-3.0 * torch.ones(c) + 0.05 * torch.randn(c)), requires_grad=True)
        )
    )
    # Domain-specific variational distributions over the linear layer's parameters
    w_dist_domain = Normal(w_mu_domain, w_log_sigma_domain)
    b_dist_domain = Normal(b_mu_domain, b_log_sigma_domain)
    dists = {'linear.weight': w_dist_domain, 'linear.bias': b_dist_domain}
    lifted_module = pyro.random_module(f'{d}_module', models[d], dists)
    return lifted_module()
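And here is roughly how I run inference (the optimizer settings are just what I’ve been trying, and num_epochs / batches are placeholders for my actual data loop; the nan shows up in the value returned by svi.step once the domain-specific prior is included):

from pyro.infer import SVI
from pyro.optim import Adam

optim = Adam({'lr': 0.01})
svi = SVI(model, guide, optim, loss='ELBO')   # Trace_ELBO in newer Pyro versions

for epoch in range(num_epochs):
    for X_batch, y_batch, d_batch in batches:   # d_batch is the domain string for the batch
        loss = svi.step(X_batch, y_batch, d_batch)
        print(epoch, loss)                      # becomes nan with the domain-specific prior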