I'm testing the AIES sampler, which should behave similarly to emcee, on a very simple 1-D Gaussian model built as follows.
import jax.numpy as jnp
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, AIES, ESS
import arviz as az
sample_key, data_key, noise_key = jax.random.split(jax.random.PRNGKey(18), 3)
data = jax.random.normal(data_key, (5000,))
noise = jax.random.normal(noise_key, (5000,))*0.01
data += noise
def loglike(mu, sigma):
    chi = (data - mu) / sigma
    log_norm = -0.5 * jnp.log(2 * jnp.pi * sigma**2)
    return jnp.sum(log_norm - 0.5 * chi**2)
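As a quick sanity check (a sketch, not part of the original test), the handwritten likelihood can be compared against jax.scipy's closed-form Gaussian logpdf:

from jax.scipy.stats import norm
# the handwritten loglike should match the Gaussian logpdf summed over the data
print(loglike(0.0, 1.0), norm.logpdf(data, loc=0.0, scale=1.0).sum())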
priors = {
    "mu": dist.Uniform(-5., 5.),
    "sigma": dist.Uniform(0., 10.),
}
def model():
    params = {k: numpyro.sample(k, priors[k]) for k in priors}
    log_like = loglike(**params)
    numpyro.factor("log_likelihood", log_like)
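Before sampling, the joint log density can be spot-checked at a point; a minimal sketch, assuming numpyro.infer.util.log_density takes (model, model_args, model_kwargs, params) and returns the log joint plus a trace:

from numpyro.infer.util import log_density
# evaluate the joint log density at a reasonable point; it should be finite
log_joint, _ = log_density(model, (), {}, {"mu": 0.0, "sigma": 1.0})
print(log_joint)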
Even though the model is very simple, the AIES sampler with the minimum number of walkers (4, since the parameter space is 2-D) does not mix the walkers: the result is a chain with the same values repeated.
kernel = AIES(model)
mcmc = MCMC(kernel,
            num_warmup=5000,
            num_samples=5000,
            num_chains=4,
            progress_bar=True,
            chain_method='vectorized')
mcmc.run(sample_key)
res = az.from_numpyro(mcmc)
az.plot_trace(res)  # → UserWarning: Your data appears to have a single value or no finite values; the walkers do not walk
az.summary(res) # → rhat NaN
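One way to confirm the walkers are stuck (a small diagnostic sketch) is to count the distinct values each walker produced; a stuck walker yields a single repeated value, which is what makes rhat NaN:

# group_by_chain=True keeps one row per walker, shape (num_chains, num_samples)
samples = mcmc.get_samples(group_by_chain=True)
for name, s in samples.items():
    print(name, [jnp.unique(s[c]).size for c in range(s.shape[0])])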
The same problem occurs with 6 walkers; the sampler only starts working with 8 or 12 walkers (see the AIES sketch after the ESS example below). By contrast, the ESS sampler already works perfectly with 4 walkers.
kernel = ESS(model)
mcmc = MCMC(kernel,
            num_warmup=5000,
            num_samples=5000,
            num_chains=4,
            progress_bar=True,
            chain_method='vectorized')
mcmc.run(sample_key)
res = az.from_numpyro(mcmc)
az.plot_trace(res) # nice results!
az.summary(res) # rhat = 1.0, ess very high
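For reference, here is the 8-walker AIES variant mentioned above (a sketch, identical to the 4-walker run except for num_chains):

kernel = AIES(model)
mcmc = MCMC(kernel,
            num_warmup=5000,
            num_samples=5000,
            num_chains=8,  # 8 walkers instead of the 2*ndim minimum of 4
            progress_bar=True,
            chain_method='vectorized')
mcmc.run(sample_key)
# with 8 walkers the trace plots look reasonable and rhat becomes finite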
Finally, emcee works fine with just 4 walkers (at least it mixes the chains correctly).
import emcee
import numpy as np
import arviz as az
def log_prior_emcee(mu, sigma):
    condition = (mu < 5.) & (mu > -5.) & (sigma > 0.) & (sigma < 10.)
    return np.where(condition, 0., -np.inf)
def loglike_emcee(params):
    mu = params[:, 0]
    sigma = params[:, 1]
    log_p = log_prior_emcee(mu, sigma)  # 0 inside the prior box, -inf outside
    chi = (data - mu[:, None]) / sigma[:, None]
    log_norm = -0.5 * np.log(2 * np.pi * sigma**2)
    return np.sum(log_norm[:, None] - 0.5 * chi**2, axis=-1) + log_p
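Since vectorize=True makes emcee pass all walkers at once, loglike_emcee expects params of shape (nwalkers, ndim); a quick shape check (a sketch with hypothetical test values inside the prior box):

# with vectorize=True, emcee calls log_prob_fn with shape (nwalkers, ndim)
test = np.array([[0.0, 1.0], [1.0, 2.0]])
print(loglike_emcee(test).shape)  # expect (2,): one log-probability per walker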
nwalkers = 4
p0mu = np.random.uniform(-5,5,size=nwalkers)
p0sigma = np.random.uniform(0,10,size=nwalkers)
p0 = np.column_stack([p0mu, p0sigma])  # shape (nwalkers, 2)
sampler = emcee.EnsembleSampler(nwalkers=nwalkers,
                                ndim=2,
                                log_prob_fn=loglike_emcee,
                                vectorize=True)
print("warmup")
sampler.run_mcmc(p0, 5000, progress=True)
p0_after_warmup=sampler.get_last_sample()
sampler.reset()
print("sample")
sampler.run_mcmc(p0_after_warmup, 5000, progress=True)
res = az.from_emcee(sampler, var_names=["mu", "sigma"])  # name the parameters for the summary
az.plot_trace(res)
print(az.summary(res))
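As a final cross-check mirroring the numpyro diagnostic above (a sketch), the raw emcee chain can be inspected per walker:

chain = sampler.get_chain()  # shape (nsteps, nwalkers, ndim)
print(chain.std(axis=0))     # nonzero std along the step axis means the walkers move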