I am fitting a large hierarchical model to some data, and I think it would benefit from HMC-within-Gibbs sampling. Is there a way to specify a group of parameters within a model that should automatically be Gibbs-sampled as a block?
As a toy example, take this model:
```python
import numpyro
import numpyro.distributions as dist


def model(data_x, data_y):
    # Group-level hyperparameters
    mu_a = numpyro.sample('mu_a', dist.Normal(0, 100))
    sigma_a = numpyro.sample('sigma_a', dist.HalfNormal(100))
    mu_b = numpyro.sample('mu_b', dist.Normal(0, 5))
    sigma_b = numpyro.sample('sigma_b', dist.HalfNormal(3))
    # Per-point coefficients
    with numpyro.plate('points', data_y.shape[0]):
        a = numpyro.sample('a', dist.Normal(mu_a, sigma_a))
        b = numpyro.sample('b', dist.Normal(mu_b, sigma_b))
    est_y = a * data_x + b
    sigma = numpyro.sample('sigma', dist.HalfNormal(50))
    with numpyro.plate("data", data_y.shape[0]):
        numpyro.sample("obs", dist.Normal(est_y, sigma), obs=data_y)
```
I want the sampler to first take an HMC step on `mu_a`, `sigma_a`, `mu_b`, `sigma_b` while keeping `a`, `b`, `sigma` fixed, and then take an HMC step on `a`, `b`, `sigma` while keeping `mu_a`, `sigma_a`, `mu_b`, `sigma_b` fixed. Ideally, I would like to do this without having to rewrite the model.
HMCECS does something similar, but I want to be able to choose which parameters go into the Gibbs step, rather than having them be a random subset of the plate (in the full model, some of the parameters I want to Gibbs-sample over are not in plates). I tried reading the HMCECS source code to see how it splits the model up, but I can't really follow what is going on.
Is there an easy way to do this in Numpyro?