AIES does not mix chains when there are few walkers

I’m testing the AIES sampler, which should behave similarly to emcee, with a very simple 1-D Gaussian model built as follows.

import jax.numpy as jnp
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, AIES, ESS
import arviz as az

sample_key, data_key, noise_key = jax.random.split(jax.random.PRNGKey(18), 3)
# 5000 draws from N(0, 1) plus a small amount of observation noise
data = jax.random.normal(data_key, (5000,))
noise = jax.random.normal(noise_key, (5000,)) * 0.01
data += noise

def loglike(mu, sigma):
  # Gaussian log-likelihood summed over all data points
  chi = (data - mu) / sigma
  log_norm = -0.5 * jnp.log(2 * jnp.pi * sigma**2)
  return jnp.sum(log_norm - 0.5 * chi**2)

priors = {
  "mu": dist.Uniform(-5.,5.),
  "sigma": dist.Uniform(0.,10.),
}

def model():
  params = {k: numpyro.sample(k, priors[k]) for k in priors}
  log_like = loglike(**params)
  numpyro.factor("log_likelihood", log_like)
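As a quick sanity check, the log-likelihood is finite at the generating parameters (mu = 0, sigma ≈ 1), so whatever goes wrong below is in the sampler, not the model:

print(loglike(0.0, 1.0))  # prints a finite value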

Even though the model is very simple, the AIES sampler with the minimum number of walkers (4, since the parameter space is 2-D) does not mix the walkers. The result is a chain in which the same values repeat.

kernel = AIES(model)
mcmc = MCMC(kernel,
  num_warmup=5000,
  num_samples=5000,
  num_chains=4,
  progress_bar=True,
  chain_method='vectorized'
)

mcmc.run(sample_key)
res = az.from_numpyro(mcmc)
az.plot_trace(res) # → UserWarning: Your data appears to have a single value or no finite values; walkers do not walk
az.summary(res) # → r_hat is NaN
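One way to confirm that the walkers are stuck is to count the unique values per chain; a healthy chain has roughly as many unique values as samples, while a stuck walker repeats a handful of values:

import numpy as np

samples = mcmc.get_samples(group_by_chain=True)  # arrays of shape (num_chains, num_samples)
for name, values in samples.items():
  print(name, np.unique(np.asarray(values)).size)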

The same problem occurs with 6 walkers; the sampler only starts working with 8–12 walkers. By contrast, the ESS sampler already works perfectly with 4 walkers.

kernel = ESS(model)
mcmc = MCMC(kernel,
  num_warmup=5000,
  num_samples=5000,
  num_chains=4,
  progress_bar=True,
  chain_method='vectorized'
)

mcmc.run(sample_key)
res = az.from_numpyro(mcmc)
az.plot_trace(res) # nice results!
az.summary(res)  # r_hat = 1.0, effective sample size very high
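For reference, a quick sweep over walker counts (a sketch reusing the model and key defined above; the exact numbers depend on the seed) shows where AIES starts mixing:

for n_walkers in (4, 6, 8, 12):
  mcmc = MCMC(AIES(model),
    num_warmup=5000,
    num_samples=5000,
    num_chains=n_walkers,
    progress_bar=False,
    chain_method='vectorized'
  )
  mcmc.run(sample_key)
  print(n_walkers, "walkers:")
  print(az.summary(az.from_numpyro(mcmc))[["r_hat"]])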

Finally, emcee already works with 4 walkers (at least it mixes the chains correctly).

import emcee
import numpy as np
import arviz as az

def log_prior_emcee(mu, sigma):
  # flat prior matching the numpyro model
  condition = (mu < 5.) & (mu > -5.) & (sigma > 0.) & (sigma < 10.)
  return np.where(condition, 0., -np.inf)

def loglike_emcee(params):
  # vectorized over walkers: params has shape (nwalkers, ndim)
  mu = params[:, 0]
  sigma = params[:, 1]
  log_p = log_prior_emcee(mu, sigma)
  chi = (np.asarray(data) - mu[:, None]) / sigma[:, None]
  log_norm = -0.5 * np.log(2 * np.pi * sigma**2)
  ll = np.sum(log_norm[:, None] - 0.5 * chi**2, axis=-1)
  return ll + log_p  # add the log-prior once, not once per data point

nwalkers = 4
p0mu = np.random.uniform(-5, 5, size=nwalkers)
p0sigma = np.random.uniform(0, 10, size=nwalkers)
p0 = np.column_stack([p0mu, p0sigma])  # shape (nwalkers, ndim)

sampler = emcee.EnsembleSampler(nwalkers=nwalkers, 
  ndim=2, 
  log_prob_fn=loglike_emcee, 
  vectorize=True
)
print("warmup")
sampler.run_mcmc(p0, 5000, progress=True)
p0_after_warmup=sampler.get_last_sample()
sampler.reset()
print("sample")
sampler.run_mcmc(p0_after_warmup, 5000, progress=True)
res = az.from_emcee(sampler)
az.plot_trace(res)
print(az.summary(res))

The problem is solved by setting randomize_split=True (which permutes the walker order before splitting the ensemble in half at each iteration, presumably avoiding a fixed, degenerate split when there are only a handful of walkers), i.e. by defining the AIES kernel as

kernel = AIES(model, randomize_split=True)
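Rerunning the exact same 4-walker setup with this kernel now mixes:

mcmc = MCMC(kernel,
  num_warmup=5000,
  num_samples=5000,
  num_chains=4,
  progress_bar=True,
  chain_method='vectorized'
)
mcmc.run(sample_key)
az.summary(az.from_numpyro(mcmc))  # r_hat is now finite with only 4 walkers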

Would it be better to set the default value of randomize_split to True, as in the ESS kernel?

cc @amifalk

Seems fine to me if you want to open a PR to change the behavior - not sure why I didn’t write it like that initially. Looks like emcee also defaults to randomize_split=True for its moves. As an aside, even though the recommended minimum number of chains is 2x the dimension of the space, you will likely want more than that to ensure good mixing.