HMC Gibbs across a subset of model parameters

I am fitting a large hierarchical model to some data, and I think it would benefit from HMC-within-Gibbs sampling. Is there a way to specify a group of parameters within a model to be automatically Gibbs sampled over?

As a toy example take this model:

import numpyro
import numpyro.distributions as dist

def model(data_x, data_y):
    mu_a = numpyro.sample('mu_a', dist.Normal(0, 100))
    sigma_a = numpyro.sample('sigma_a', dist.HalfNormal(100))
    mu_b = numpyro.sample('mu_b', dist.Normal(0, 5))
    sigma_b = numpyro.sample('sigma_b', dist.HalfNormal(3))

    with numpyro.plate('points', data_y.shape[0]):
        a = numpyro.sample('a', dist.Normal(mu_a, sigma_a))
        b = numpyro.sample('b', dist.Normal(mu_b, sigma_b))

    est_y = a * data_x + b
    sigma = numpyro.sample('sigma', dist.HalfNormal(50))

    with numpyro.plate("data", data_y.shape[0]):
        numpyro.sample("obs", dist.Normal(est_y, sigma), obs=data_y)

I want the sampler to first take an HMC step over mu_a, sigma_a, mu_b, sigma_b while keeping a, b, sigma fixed, followed by an HMC step over a, b, sigma while keeping mu_a, sigma_a, mu_b, sigma_b fixed. Ideally, I would like to do this without having to rewrite the model.

HMCECS does something similar, but I want to be able to pick which parameters are in the Gibbs step rather than having it be a random subset from the plate (in the full model, some of the parameters I want to Gibbs over are not in plates). I tried looking at the HMCECS source code to see how it splits the model up, but I can’t really follow what is going on.

Is there an easy way to do this in Numpyro?


there is HMCGibbs but you need to implement your own custom gibbs step (numpyro handles the hmc)

https://num.pyro.ai/en/stable/mcmc.html?highlight=hmcgibbs#numpyro.infer.hmc_gibbs.HMCGibbs
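the basic pattern from that docstring looks roughly like this (a minimal sketch with a throwaway two-parameter model, not your model; the gibbs_fn returns new values for the gibbs sites conditioned on the current values of the hmc sites):

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def toy():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # exact Gaussian conditional of x given y and the single observation
    y = hmc_sites["y"]
    new_x = dist.Normal(0.8 * (1.0 - y), jnp.sqrt(0.8)).sample(rng_key)
    return {"x": new_x}

kernel = HMCGibbs(NUTS(toy), gibbs_fn=gibbs_fn, gibbs_sites=["x"])
mcmc = MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(jax.random.PRNGKey(0))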

From the sound of the HMCECS documentation, it can create the custom Gibbs step automatically from the model; that is the kind of solution I am looking for, if possible.

But failing that, how would you go about making a custom Gibbs step for the above model? It is not clear from the HMCGibbs documentation whether it needs to be a fully analytic Gibbs step, or whether it can be another HMC step over different parameters. It sounds like it is set up to do HMC followed by an analytic update, not HMC followed by HMC.

We have this issue, Make Gibbs kernels composable · Issue #898 · pyro-ppl/numpyro · GitHub, which would allow such customization. I still believe that the solution is simple and similar to what we did in, e.g., HMCECS.

well i don’t think doing hmc followed by hmc would buy you anything.

if that model is the model you actually care about then presumably sampling is more or less easy with hmc out of the box (though it’d probably work better if you re-normalized things so large numbers like 100 don’t appear), unless data_y.shape[0] is large and thus the dimensions of a and b get large.

however a and b are amenable to a gibbs step: the conditional posterior of a and b is gaussian. so you would just need to compute the mean and covariance matrix of that gaussian as a function of the other random variables and then follow the pattern in the docs
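concretely, a sketch of that update for the toy model above (assuming model, data_x and data_y from the first post are in scope; the conditional_ab helper is just a name for this sketch):

import jax
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def conditional_ab(x, y, hmc_sites):
    # per-point 2x2 Gaussian conditional for (a_i, b_i) given the group-level
    # parameters and the observation noise (conjugate linear-Gaussian update)
    mu_a, sigma_a = hmc_sites["mu_a"], hmc_sites["sigma_a"]
    mu_b, sigma_b = hmc_sites["mu_b"], hmc_sites["sigma_b"]
    sigma = hmc_sites["sigma"]
    prior_prec = jnp.diag(jnp.stack([1.0 / sigma_a**2, 1.0 / sigma_b**2]))  # (2, 2)
    prior_mean = jnp.stack([mu_a, mu_b])                                    # (2,)
    design = jnp.stack([x, jnp.ones_like(x)], axis=-1)                      # (n, 2)
    post_prec = prior_prec + design[:, :, None] * design[:, None, :] / sigma**2
    post_cov = jnp.linalg.inv(post_prec)                                    # (n, 2, 2)
    post_mean = jnp.einsum("nij,nj->ni", post_cov,
                           prior_prec @ prior_mean + design * (y / sigma**2)[:, None])
    return post_mean, post_cov

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    post_mean, post_cov = conditional_ab(data_x, data_y, hmc_sites)
    ab = dist.MultivariateNormal(post_mean, post_cov).sample(rng_key)       # (N, 2)
    return {"a": ab[:, 0], "b": ab[:, 1]}

kernel = HMCGibbs(NUTS(model), gibbs_fn=gibbs_fn, gibbs_sites=["a", "b"])
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), data_x, data_y)

hmc then handles mu_a, sigma_a, mu_b, sigma_b and sigma, while a and b get exact conditional draws.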

in any case i’d try massaging your model first before trying anything fancy: Bad posterior geometry and how to deal with it — NumPyro documentation
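for example, one of the standard tricks from that page is a non-centered parameterization of a and b, which doesn’t require rewriting the model (a minimal sketch, again assuming model, data_x and data_y are in scope):

import jax
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import LocScaleReparam

# decouple a, b from (mu_a, sigma_a, mu_b, sigma_b) by sampling standardized
# auxiliary variables instead of a and b directly
reparam_model = reparam(model, config={"a": LocScaleReparam(centered=0.0),
                                       "b": LocScaleReparam(centered=0.0)})
mcmc = MCMC(NUTS(reparam_model), num_warmup=1000, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), data_x, data_y)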

The full problem is a gravitational lensing problem from astronomy. One set of parameters is for the unknown mass distribution (about 10 free parameters), and the other is for the light behind the lensing galaxy (a 100x100 pixel grid, so 10,000 free parameters with a Gaussian random field prior).

The mass tells the light from the pixel grid how to bend, producing a model of the observed data. Because of how the physics works, I want to split the model up: first hold the mass parameters fixed and update the light pixel grid, then hold the light fixed and update the mass. A random change to all the parameters at once is unlikely to be consistent with the data and will tend to give a worse likelihood.

For simple lens systems, HMC just works out of the box, but for sufficiently complex systems it struggles to converge. In the complex systems, if we hold the mass parameters fixed at “good enough” values (found with SVI), HMC converges without issue, which is an initial indication that splitting the sampling up could help.

I agree the toy model I put in the example is quite simple and would work as is (after reparameterizing). Also, I have already tried all the suggestions on the “Bad posterior geometry” documentation page without any of them helping.

well if you can do a gibbs step on the gaussian random field i suspect that would work quite well. i’m also unconvinced that interleaving HMC steps will get you much of anything. one of the problems with HMC in high dimensions is that in order for it to work well it needs a pretty well-tuned step size and mass matrix, which can be hard to come by even with a long warm-up. by contrast, gibbs steps are parameter-free. note that the gibbs_fn gets a random key, so you can use that to choose a subset of the random field to update in each step (to limit the cost of the linear algebra).
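for instance, building on the toy-model sketch above (reusing the conditional_ab helper from there), each step could refresh only a random tenth of the (a_i, b_i) pairs; since the pairs are conditionally independent given the group-level parameters, this is still a valid gibbs move. for your random field the pixels are coupled by the prior, so the subset conditional would need a bit more linear algebra, but the pattern is the same:

def gibbs_fn_subset(rng_key, gibbs_sites, hmc_sites):
    key_idx, key_draw = jax.random.split(rng_key)
    n = data_y.shape[0]
    # pick a random 10% of the points to refresh this iteration
    idx = jax.random.choice(key_idx, n, shape=(n // 10,), replace=False)
    post_mean, post_cov = conditional_ab(data_x[idx], data_y[idx], hmc_sites)
    ab = dist.MultivariateNormal(post_mean, post_cov).sample(key_draw)
    a = gibbs_sites["a"].at[idx].set(ab[:, 0])
    b = gibbs_sites["b"].at[idx].set(ab[:, 1])
    return {"a": a, "b": b}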

I just wanted to write a quick update on this. I spent some time understanding how HMCGibbs works and was able to implement a version that takes alternating HMC steps across sets of parameters for a single model (in fact, I wrote it to take a list of HMC kernels and a list of parameter sets and loop over each set in turn, holding all other sets fixed), and found it works remarkably well for my problem.

With standard HMC on the full model (using a large max_tree_depth, target_accept_prob, and num_warmup), it would take close to a day to draw a few thousand samples, and the results were nowhere near converged. The new custom “multi-HMC Gibbs” sampler gives me 5000 draws from 4 chains, fully converged, in 20 minutes.

Hopefully in the next few weeks, I can find some time to put this together as a PR with some examples in case others find it useful.
