I’m trying to write a correct guide for a very simple hierarchical logistic model. The model has only an indexed intercept parameter (one intercept per level of a categorical variable), a distribution that these intercepts are drawn from, and hyperpriors for that distribution. When I compare my SVI estimates with the HMC estimates, they are way off. I then tried an AutoDiagonalNormal guide and its estimates matched HMC, so there must be something wrong with the manual guide I wrote, but I can’t figure out what it is.
I’ve confirmed that the model is written correctly because the HMC estimates are where they’re supposed to be (this is an adaptation of an example from Statistical Rethinking).
    # Note: n is number of trials, surv is number of successes, tank is the incrementing index for the different groups/intercepts.
    def model_1(n, surv, tank):
        # Hyperpriors.
        a_bar = pyro.sample("a_bar", dist.Normal(0, 1.5))
        sigma = pyro.sample("sigma", dist.Exponential(1))
        # Adaptive prior (of intercepts).
        a = pyro.sample("a", dist.Normal(a_bar, sigma).expand([tank.shape]))
        # Linear model only has an intercept (indexed).
        logit_p = a[tank]
        # Generative distribution for the data.
        with pyro.plate('data'):
            pyro.sample("surv", dist.Binomial(n, logits=logit_p), obs=surv)
And my guide (which must be incorrect somewhere) is:
    def guide_1(n, surv, tank):
        # Variational parameters for adaptive prior.
        a_bar_loc = pyro.param("a_bar_loc", torch.tensor(0.))
        a_bar_scale = pyro.param("a_bar_scale", torch.tensor(1.), constraint=constraints.positive)
        sigma_rate = pyro.param("sigma_rate", torch.tensor(1.), constraint=constraints.positive)
        # Variational parameters for latent indexed intercept.
        a_loc = pyro.param("a_loc", torch.zeros(tank.shape))
        a_scale = pyro.param("a_scale", torch.ones(tank.shape), constraint=constraints.positive)
        # Distributions based on variational parameters.
        pyro.sample("a_bar", dist.Normal(a_bar_loc, a_bar_scale))
        pyro.sample("sigma", dist.Exponential(sigma_rate))
        pyro.sample("a", dist.Normal(a_loc, a_scale))
Does anyone know what I could do to fix my guide so that I get estimates similar to those from AutoDiagonalNormal (or even an AutoMultivariateNormal)? Thanks!