How does HMCGibbs work?

Hi there,

I’ve been playing around with HMCGibbs. The exact code below isn’t what I’m using in my model, but it demonstrates what I’ve been finding in my more complex case. (My case uses a custom likelihood with numpyro.factor, so if it’s of interest, I can show my more complicated model where the problem occurs.)

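```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def model():
    x = numpyro.sample("x", dist.Uniform(-100000, -100)) # option 1
    # x = numpyro.sample("x", dist.Uniform(100, 1000)) # option 2
    y = numpyro.sample("y", dist.HalfNormal(1.0))
    numpyro.sample("obs", dist.Poisson(x - y), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    y = hmc_sites['y']  # not used by this toy proposal
    new_x = dist.Uniform(10, 11).sample(rng_key)
    return {'x': new_x}

hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0))  # fails with option 1, runs with option 2
```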

My understanding was that the sample statements in model() for sites that are going to be replaced in gibbs_fn (here x) are just placeholders – I could put whatever distribution I want. What I'm finding is that numpyro needs the sample statement in model() to at least be physically correct. Here, option 1 does not work because the Poisson rate must be positive. Option 2 does work because it keeps the rate positive, even though new_x replaces x when using HMCGibbs.

Why is it necessary for x to be sensible, even when it is replaced during the Gibbs step?

Best,

Theo

You can use ImproperUniform if you don't want to use a proper distribution. Your gibbs_fn does not use the prior density of x, but in other cases, like DiscreteHMCGibbs, we need to use that density. I guess your main question is why the Gibbs proposal needs to be in the support of the prior. We currently don't have a mechanism to automatically assign -inf to out-of-support values, and I suspect inference wouldn't work with -inf values anyway. If you know that your proposals lie in a different domain, you should adjust the prior accordingly (even if you have no other information, you can use an improper prior). You can also use ImproperUniform(real), i.e. dist.ImproperUniform(dist.constraints.real, (), ()), to cover the whole real line.