No worries, I’ve re-run the following code:
```python
import numpy as np
import numpyro
from numpyro.infer import MCMC, DiscreteHMCGibbs, HMC, init_to_value
import numpyro.distributions as dist
from jax import random


def model():
    index = numpyro.sample('index', dist.Categorical(probs=np.array([[0.2, 0.2, 0.2, 0.2, 0.2]])))
    mu = numpyro.sample('mu', dist.Normal(0, 1), sample_shape=(2,))
    return


kernel = DiscreteHMCGibbs(HMC(model), init_strategy=init_to_value(values={'index': np.array([1, 4, 2, 3, 0])}))
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1, progress_bar=False)
print('sampler: ', mcmc.sampler, ', init strategy: ', mcmc.sampler._init_strategy)
mcmc._compile(random.PRNGKey(0))
print('\nlast state: ', mcmc._last_state.z)
mcmc.warmup(random.PRNGKey(0), collect_warmup=True)
print('\nfirst iter of warmup: ', (mcmc.get_samples())['index'][0, :])
```
This outputs:
```
sampler: <numpyro.infer.hmc_gibbs.DiscreteHMCGibbs object at 0x7fbc005ad460> , init strategy: functools.partial(<function init_to_value at 0x7fbbc0a44f70>, values={'index': array([1, 4, 2, 3, 0])})
last state: {'index': DeviceArray([1, 4, 2, 3, 0], dtype=int32), 'mu': DeviceArray([-0.3011737, 0.684855 ], dtype=float32)}
first iter of warmup: [2 3 3 3 4]
```
The last state after compilation matches the initial value I specified ([1, 4, 2, 3, 0]), but the first iteration of the warmup does not.
Does the first warmup iteration correspond to the next sample drawn after starting at the specified initial value, i.e. is the initial value itself never recorded?
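To make the question concrete, here is a minimal, NumPyro-free sketch of a generic MCMC loop. It assumes (not confirmed here) that the sampler records the state produced by each transition rather than the starting point. `gibbs_sweep` and `run_chain` are hypothetical helpers, and the update rule is a simplification of the discrete part of the model: with a uniform prior over 5 categories and no likelihood, each index's full conditional is uniform over {0, ..., 4}, so a Gibbs sweep simply redraws every element.

```python
import random


def gibbs_sweep(state, num_categories, rng):
    """One systematic-scan Gibbs sweep over independent uniform categoricals.

    Hypothetical helper: with a uniform prior and no likelihood, each
    element's full conditional is Uniform{0, ..., num_categories - 1},
    so the new value does not depend on the old one.
    """
    new_state = list(state)
    for i in range(len(new_state)):
        new_state[i] = rng.randrange(num_categories)
    return new_state


def run_chain(init_state, num_samples, num_categories, seed):
    """Hypothetical MCMC driver: start at init_state, record post-transition states."""
    rng = random.Random(seed)
    state = list(init_state)  # starting point of the chain (never recorded)
    samples = []
    for _ in range(num_samples):
        state = gibbs_sweep(state, num_categories, rng)  # one transition first...
        samples.append(state)  # ...then record the resulting state
    return samples


init = [1, 4, 2, 3, 0]
samples = run_chain(init, num_samples=3, num_categories=5, seed=0)
print("init:         ", init)
print("first sample: ", samples[0])
```

Under this convention, `samples[0]` is one sweep away from `[1, 4, 2, 3, 0]`, and because the conditionals here ignore the current value, it can differ from the initial value in every position, which would be consistent with seeing `[2 3 3 3 4]`.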