Hi,

I was playing around with the Intro to Pyro tutorial and noticed that the AutoNormal guide seems to greatly outperform the custom guide that is supposed to be equivalent to it (the only difference, as far as I can tell, being the initial parameter values).

Both guides are supposed to be diagonal Gaussians, but the AutoNormal guide seems to reach a better posterior than the custom guide, and to get there much faster. I've looked through the AutoNormal source code but can't figure out why it would perform better, and I'd be surprised if different starting parameter values made that much of a difference in such a simple model (the starting values in the custom guide aren't that bad, either). Can anyone shed light on the performance gap?

The code for the custom guide is

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def custom_guide(is_cont_africa, ruggedness, log_gdp=None):
    # Variational parameters.
    intercept_loc = pyro.param('intercept_loc', lambda: torch.tensor(0.))
    intercept_scale = pyro.param('intercept_scale', lambda: torch.tensor(1.),
                                 constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', lambda: torch.randn(3))
    weights_scale = pyro.param('weights_scale', lambda: torch.ones(3),
                               constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(1.),
                           constraint=constraints.positive)
    # Variational distributions (each is a Gaussian; jointly they form a
    # diagonal-covariance Gaussian).
    intercept = pyro.sample('intercept', dist.Normal(intercept_loc, intercept_scale))
    beta_africa = pyro.sample('beta_africa', dist.Normal(weights_loc[0], weights_scale[0]))
    beta_rugged = pyro.sample('beta_rugged', dist.Normal(weights_loc[1], weights_scale[1]))
    beta_interact = pyro.sample('beta_interact', dist.Normal(weights_loc[2], weights_scale[2]))
    sigma = pyro.sample('sigma', dist.Normal(sigma_loc, torch.tensor(0.05)))
    return {'intercept': intercept, 'beta_africa': beta_africa,
            'beta_rugged': beta_rugged, 'beta_interact': beta_interact,
            'sigma': sigma}
```

and the code for the model is

```python
def bayesian_regression(is_cont_africa, ruggedness, log_gdp=None):
    # Priors.
    intercept = pyro.sample('intercept', dist.Normal(0., 10.))
    beta_africa = pyro.sample('beta_africa', dist.Normal(0., 1.))
    beta_rugged = pyro.sample('beta_rugged', dist.Normal(0., 1.))
    beta_interact = pyro.sample('beta_interact', dist.Normal(0., 1.))
    sigma = pyro.sample('sigma', dist.Uniform(0., 10.))
    # Linear model.
    mean = (intercept + beta_africa * is_cont_africa + beta_rugged * ruggedness
            + beta_interact * is_cont_africa * ruggedness)
    # Likelihood.
    with pyro.plate('data', len(ruggedness)):
        return pyro.sample('obs', dist.Normal(mean, sigma), obs=log_gdp)
```