Simple Hierarchical Guide

Hi,

I’m trying to write a correct guide for a very simple hierarchical logistic model. The model has only an indexed intercept parameter (one intercept per level of a categorical variable), a distribution that these intercepts are drawn from, and hyperpriors for that distribution. When I compare my SVI estimates with the HMC estimates, they are way off. I then tried an AutoDiagonalNormal guide and its estimates matched HMC, so there must be something wrong with the manual guide I wrote. But I can’t figure out what it is.

I’ve confirmed that the model is written correctly because the HMC estimates are where they’re supposed to be (this is an adaptation of an example from Statistical Rethinking).

    import pyro
    import pyro.distributions as dist
    import torch
    from torch.distributions import constraints

    # Note: n is the number of trials, surv is the number of successes, and tank is the integer index of each group/intercept.
    def model_1(n, surv, tank):
        # Hyperpriors.
        a_bar = pyro.sample("a_bar", dist.Normal(0, 1.5))
        sigma = pyro.sample("sigma", dist.Exponential(1))
        # Adaptive prior (of intercepts).
        a = pyro.sample("a", dist.Normal(a_bar, sigma).expand([tank.shape[0]]))
        # Linear model only has an intercept (indexed).
        logit_p = a[tank]
        # Generative distribution for the data.
        with pyro.plate('data'):
            pyro.sample("surv", dist.Binomial(n, logits=logit_p), obs=surv)

And my guide (which must be incorrect somewhere) is:

    def guide_1(n, surv, tank):
        # Variational parameters for adaptive prior.
        a_bar_loc = pyro.param("a_bar_loc", torch.tensor(0.))
        a_bar_scale = pyro.param("a_bar_scale", torch.tensor(1.), constraint=constraints.positive)
        sigma_rate = pyro.param("sigma_rate", torch.tensor(1.), constraint=constraints.positive)
        # Variational parameters for latent indexed intercept.
        a_loc = pyro.param("a_loc", torch.zeros(tank.shape[0]))
        a_scale = pyro.param("a_scale", torch.ones(tank.shape[0]), constraint=constraints.positive)
        # Distributions based on variational parameters.
        pyro.sample("a_bar", dist.Normal(a_bar_loc, a_bar_scale))
        pyro.sample("sigma", dist.Exponential(sigma_rate))
        pyro.sample("a", dist.Normal(a_loc, a_scale))

Does anyone know how I could fix my guide so that I get estimates similar to those from AutoDiagonalNormal (or even AutoMultivariateNormal)? Thanks!

Your guide looks reasonable to me. Could you try the following tweaks:

  • reduce the initial values of the scale parameters to 0.1
  • use a LogNormal guide for sigma

Wow, thanks! That appears to have fixed it. The Exponential distribution seems to have been the main issue; possibly its fatter tails, combined with the logit link, were preventing SVI from finding good estimates (I’m not sure, but maybe extreme values in the tails were being squashed by the logistic function and so still had reasonable probability).
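One property that may also contribute, separate from tail behavior: Exponential(rate) is a one-parameter family, so its mean and standard deviation are both 1/rate and cannot be tuned separately, whereas LogNormal(loc, scale) has two parameters. The sketch below uses the closed-form moments in plain Python to illustrate this. The target values 1.6 and 0.2 are hypothetical, chosen only for illustration and not taken from this model’s posterior, and the helper names are made up for this example.

```python
import math


def exponential_moments(rate):
    """Mean and standard deviation of Exponential(rate); both equal 1/rate."""
    return 1.0 / rate, 1.0 / rate


def lognormal_moments(loc, scale):
    """Mean and standard deviation of LogNormal(loc, scale)."""
    mean = math.exp(loc + 0.5 * scale ** 2)
    sd = mean * math.sqrt(math.exp(scale ** 2) - 1.0)
    return mean, sd


# Hypothetical target for sigma: a fairly concentrated distribution.
target_mean, target_sd = 1.6, 0.2

# Exponential: matching the mean fixes the standard deviation at the same value.
exp_mean, exp_sd = exponential_moments(rate=1.0 / target_mean)

# LogNormal: solve for (loc, scale) so that both moments match the target.
scale = math.sqrt(math.log(1.0 + (target_sd / target_mean) ** 2))
loc = math.log(target_mean) - 0.5 * scale ** 2
ln_mean, ln_sd = lognormal_moments(loc, scale)

print(f"Exponential: mean={exp_mean:.3f}, sd={exp_sd:.3f}")
print(f"LogNormal:   mean={ln_mean:.3f}, sd={ln_sd:.3f}")
```

Under these assumptions, an Exponential guide that gets the mean of sigma right is forced to be as wide as its mean, while the LogNormal guide can be narrow and centered at the same time.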

For the record, the guide that works is below.

Thanks for your help!

    def guide_1(n, surv, tank):
        # Variational parameters for hyperparameters (of adaptive prior).
        a_bar_loc = pyro.param("a_bar_loc", torch.tensor(0.))
        a_bar_scale = pyro.param("a_bar_scale", torch.tensor(0.1), constraint=constraints.positive)
        sigma_loc = pyro.param("sigma_loc", torch.tensor(0.1), constraint=constraints.positive)
        sigma_scale = pyro.param("sigma_scale", torch.tensor(0.1), constraint=constraints.positive)
        # Variational parameters for latent parameters (intercepts).
        a_loc = pyro.param("a_loc", torch.zeros(tank.shape[0]))
        a_scale = pyro.param("a_scale", torch.ones(tank.shape[0]), constraint=constraints.positive)
        # Hyperparameter and latent parameter distributions.
        pyro.sample("a_bar", dist.Normal(a_bar_loc, a_bar_scale))
        pyro.sample("sigma", dist.LogNormal(sigma_loc, sigma_scale))
        pyro.sample("a", dist.Normal(a_loc, a_scale))
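A side note on the guide above, offered as an observation rather than a required fix: in LogNormal(loc, scale), loc is the mean of log(sigma), so it can be any real number, and the median of sigma is exp(loc). With constraint=constraints.positive on sigma_loc, the guide can only place the median of sigma above 1. Whether that matters depends on the data; omitting the constraint leaves the parameter unconstrained, which is the default for pyro.param. The sketch below checks the arithmetic in plain Python; the loc values and the helper name are hypothetical.

```python
import math


def lognormal_median(loc):
    """Median of LogNormal(loc, scale); it does not depend on scale."""
    return math.exp(loc)


# Any strictly positive loc gives a median strictly above 1.
positive_locs = [0.01, 0.1, 1.0]
medians = [lognormal_median(loc) for loc in positive_locs]

# A median below 1 (here 0.5, a hypothetical value) needs a negative loc,
# which a positivity constraint on loc rules out.
loc_for_half = math.log(0.5)

print(medians)
print(loc_for_half)
```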