I am trying to use `HMCGibbs` sampling with more than one chain using `chain_method="vectorized"`, but there appears to be a problem with splitting the random keys.
Consider this toy example, copied from the NumPyro documentation, in which I changed only `chain_method` and the number of chains:
```python
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    y = hmc_sites['y']
    new_x = dist.Normal(0.8 * (1 - y), jnp.sqrt(0.8)).sample(rng_key)
    return {'x': new_x}

hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=2,
            progress_bar=False, chain_method='vectorized')
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()
```
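For reference, the constants in `gibbs_fn` are the exact Gaussian conditional of `x` given `y` and the single observation `obs = 1.0` (prior variance $2^2 = 4$, likelihood variance $1$), so the Gibbs step itself is not the source of the problem:

```latex
\begin{aligned}
p(x \mid y, \text{obs}=1) &\propto \mathcal{N}(x;\,0,\,2^2)\,\mathcal{N}(1;\,x+y,\,1) \\
\sigma^2 &= \left(\tfrac{1}{2^2} + \tfrac{1}{1}\right)^{-1} = \tfrac{4}{5} = 0.8 \\
\mu &= \sigma^2\left(\tfrac{0}{2^2} + \tfrac{1 - y}{1}\right) = 0.8\,(1 - y)
\end{aligned}
```

That is, $x \mid y \sim \mathcal{N}(0.8\,(1-y),\, 0.8)$, which is why the code uses `0.8 * (1 - y)` as the location and `jnp.sqrt(0.8)` as the scale.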
It fails with the following error:
```
TypeError: split accepts a single key, but was given a key array of shape (2,) != (). Use jax.vmap for batching.
```
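The message comes from JAX's `random.split`, which refuses a batch of keys. My guess (not verified against the NumPyro source) is that with `chain_method="vectorized"` a batch of one key per chain reaches a `random.split` call in the `HMCGibbs` code path that expects a single key. The minimal sketch below, which assumes the default raw `PRNGKey` representation, shows the batching that the error message recommends: mapping `random.split` over the chain axis with `jax.vmap`.

```python
import jax
from jax import random

# One key per chain, as a vectorized sampler would hold them (2 chains assumed).
chain_keys = random.split(random.PRNGKey(0), 2)

# Calling random.split(chain_keys) directly is what triggers
# "split accepts a single key" in recent JAX versions.
# jax.vmap maps split over the leading (chain) axis instead,
# so each chain's key is split independently.
per_chain = jax.vmap(random.split)(chain_keys)

# per_chain[i] holds the two new keys derived from chain_keys[i].
carry_keys, sample_keys = per_chain[:, 0], per_chain[:, 1]
```

Each `per_chain[i]` is identical to `random.split(chain_keys[i])`, so the vmapped version behaves like splitting every chain's key separately.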
Is there a way to use `chain_method="vectorized"` with `HMCGibbs` in NumPyro?