HMCGibbs with chain_method="vectorized"

I am trying to use HMCGibbs sampling with more than one chain using chain_method="vectorized", but there appears to be a problem with splitting the random keys.

Consider this toy example, copied from the NumPyro documentation, in which I changed only chain_method and the number of chains:

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    y = hmc_sites['y']
    new_x = dist.Normal(0.8 * (1-y), jnp.sqrt(0.8)).sample(rng_key)
    return {'x': new_x}

hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
mcmc = MCMC(kernel, num_warmup=100, num_chains=2, num_samples=100, progress_bar=False, chain_method='vectorized')
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()

It fails with the following error:

TypeError: split accepts a single key, but was given a key array of shape (2,) != (). Use jax.vmap for batching.

Is there a way to use chain_method="vectorized" with HMCGibbs in NumPyro?
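The error message hints at the cause: with vectorized chains, a batch of per-chain keys reaches a call to random.split, which expects a single key. The sketch below uses plain JAX only and is not NumPyro's internal code. The names chain_keys and per_chain_subkeys are illustrative. It shows how jax.vmap passes split one key at a time:

```python
import jax
from jax import random

# One legacy PRNG key per chain. Each raw key is a uint32 array of shape (2,),
# so a batch of two keys has shape (2, 2).
chain_keys = random.split(random.PRNGKey(0), 2)

# random.split expects a single key, so the batch cannot be passed to it
# directly. Mapping split over the leading (chain) axis gives it one key
# per call and stacks the results.
per_chain_subkeys = jax.vmap(random.split)(chain_keys)

# Axes: (chain, subkey, raw key words)
print(per_chain_subkeys.shape)
```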

Could you create an issue on GitHub? I think we can either raise an error there or try to support jax.vmap directly in the MCMC interface.

Any update on this? I am still experiencing this issue.

I believe this is resolved by passing chain_method=jax.vmap.
