Initialize Discrete Variables in MCMC

Hi all,

I was wondering if it is possible to initialize discrete variables in DiscreteHMCGibbs, similar to how continuous variables can be initialized in HMC/NUTS. Is there an equivalent of setting init_strategy=init_to_value()?

Thank you!

Hi @emyjo, currently we don’t have an init strategy for Gibbs sites, but it is easy to support: we just need to expose the init strategy argument and use it here. Please feel free to make a feature request or submit a PR for this.

Hi @fehiepsi, thank you for the quick response! I have tried to implement your suggestion, but I’m unsure if it is working properly. Some example code is below:

import numpy as np
import numpyro
from numpyro.infer import MCMC, DiscreteHMCGibbs, HMC, init_to_value
import numpyro.distributions as dist
from jax import random

def model():
    index = numpyro.sample('index', dist.Categorical(probs=np.array([[0.2,0.2,0.2,0.2,0.2]])))
    mu = numpyro.sample('mu', dist.Normal(0,1), sample_shape=(2,))
    return

kernel = DiscreteHMCGibbs(HMC(model), init_strategy=init_to_value(values={'index':np.array([1,4,2,3,0])}))
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1, progress_bar=True)
mcmc.warmup(random.PRNGKey(0),collect_warmup=True)

print('sampler: ', mcmc.sampler, ', init strategy: ', mcmc.sampler._init_strategy)
print((mcmc.get_samples())['index'][0,:])

Which outputs:
sampler: <numpyro.infer.hmc_gibbs.DiscreteHMCGibbs object at 0x7ff9385132e0> , init strategy: functools.partial(<function init_to_value at 0x7ff918a04f70>, values={'index': array([1, 4, 2, 3, 0])})
[2 3 3 3 4]

I would have expected that the first iteration of the warmup would be equal to the specified values, i.e. that the printed array should be [1 4 2 3 0] instead of [2 3 3 3 4]. Is this working as expected?
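
For anyone reading along, here is a rough sketch of why the printed init strategy shows up as a functools.partial. It only illustrates the general pattern (call the strategy without a site to bind the values, then call the result once per site); it is not NumPyro's actual implementation. The name init_to_value_sketch, the dict-shaped site, and the choice to return None for sites without a supplied value are all assumptions made for illustration.

```python
from functools import partial


def init_to_value_sketch(site=None, values=None):
    """Hypothetical stand-in for an init-to-value strategy (not NumPyro's code)."""
    values = {} if values is None else values
    if site is None:
        # Called without a site: return a function that remembers `values`.
        return partial(init_to_value_sketch, values=values)
    # Called with a site (assumed here to be a dict with "type" and "name" keys).
    if site.get("type") == "sample" and site.get("name") in values:
        return values[site["name"]]
    # Assumption: None means "no value supplied for this site".
    return None


strategy = init_to_value_sketch(values={"index": [1, 4, 2, 3, 0]})
print(strategy)  # a functools.partial object
print(strategy({"type": "sample", "name": "index"}))
print(strategy({"type": "sample", "name": "mu"}))
```

In this sketch only the 'index' site gets a fixed starting value; any other site (such as 'mu') would need to be initialized some other way.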

Hi @emyjo, maybe you’ll also need to update init_strategy here?

Hi @fehiepsi, sorry for not being clear. I updated both the DiscreteHMCGibbs and the HMCGibbs classes just in case. The hmc_gibbs.py file is now as follows (unchanged code is skipped):

class HMCGibbs(MCMCKernel):

    def __init__(self, inner_kernel, gibbs_fn, gibbs_sites, init_strategy):
        if not isinstance(inner_kernel, HMC):
            raise ValueError("inner_kernel must be a HMC or NUTS sampler.")
        if not callable(gibbs_fn):
            raise ValueError("gibbs_fn must be a callable")
        assert (
            inner_kernel.model is not None
        ), "HMCGibbs does not support models specified via a potential function."

        self.inner_kernel = copy.copy(inner_kernel)
        self.inner_kernel._model = partial(_wrap_model, inner_kernel.model)
        self._gibbs_sites = gibbs_sites
        self._gibbs_fn = gibbs_fn
        self._prototype_trace = None
        self._init_strategy = init_strategy

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
        if self._prototype_trace is None:
            rng_key, key_u = random.split(rng_key)
            # We use init strategy to get around ImproperUniform which does not have
            # sample method.
            self._prototype_trace = trace(
                substitute(seed(self.model, key_u), substitute_fn=self._init_strategy)
            ).get_trace(*model_args, **model_kwargs)
:
:
:

class DiscreteHMCGibbs(HMCGibbs):

    def __init__(self, inner_kernel, *, random_walk=False, modified=False, init_strategy=init_to_sample):
        super().__init__(inner_kernel, identity, None, init_strategy)
        self._random_walk = random_walk
        self._modified = modified
        self._init_strategy = init_strategy
:
:
    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
        rng_key, key_u = random.split(rng_key)
        # We use init strategy to get around ImproperUniform which does not have
        # sample method.
        self._prototype_trace = trace(
            substitute(seed(self.model, key_u), substitute_fn=self._init_strategy)
        ).get_trace(*model_args, **model_kwargs)


With those changes made, I am still seeing the issue mentioned above. Is there something else that I have missed?

Sorry, I think we’ll also need to look at

mcmc._compile(random.PRNGKey(0))
print(mcmc._last_state.z)

No worries, I’ve re-run the following code:

import numpy as np
import numpyro
from numpyro.infer import MCMC, DiscreteHMCGibbs, HMC, init_to_value
import numpyro.distributions as dist
from jax import random

def model():
    index = numpyro.sample('index', dist.Categorical(probs=np.array([[0.2,0.2,0.2,0.2,0.2]])))
    mu = numpyro.sample('mu', dist.Normal(0,1), sample_shape=(2,))
    return

kernel = DiscreteHMCGibbs(HMC(model), init_strategy=init_to_value(values={'index':np.array([1,4,2,3,0])}))
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1, progress_bar=False)

print('sampler: ', mcmc.sampler, ', init strategy: ', mcmc.sampler._init_strategy)

mcmc._compile(random.PRNGKey(0))
print('\nlast state: ', mcmc._last_state.z)

mcmc.warmup(random.PRNGKey(0),collect_warmup=True)
print('\nfirst iter of warmup: ', (mcmc.get_samples())['index'][0,:])

This outputs:

sampler: <numpyro.infer.hmc_gibbs.DiscreteHMCGibbs object at 0x7fbc005ad460> , init strategy: functools.partial(<function init_to_value at 0x7fbbc0a44f70>, values={'index': array([1, 4, 2, 3, 0])})

last state: {'index': DeviceArray([1, 4, 2, 3, 0], dtype=int32), 'mu': DeviceArray([-0.3011737, 0.684855 ], dtype=float32)}

first iter of warmup: [2 3 3 3 4]

It looks like the last state of the compilation is correct, but the first iteration of the warmup is different.

Does the first iteration of the warmup phase correspond to the next sample, when starting at the specified initial value?

the first iteration of the warmup phase corresponds to the next sample, when starting at the specified initial value

That’s right! :slight_smile:
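
To make the point above concrete, here is a minimal toy sketch of a sampling loop that records the state after each transition, so the initial value itself never appears among the collected samples. This is a pure-Python stand-in under stated assumptions, not NumPyro's sampling loop: run_chain is a hypothetical helper, and its "transition" simply redraws every coordinate uniformly from the five categories.

```python
import random


def run_chain(init_state, num_steps, seed=0):
    """Toy chain (hypothetical): collects the state AFTER each transition."""
    rng = random.Random(seed)
    state = list(init_state)
    samples = []
    for _ in range(num_steps):
        # One transition: redraw each coordinate uniformly from {0, ..., 4}.
        state = [rng.randrange(5) for _ in state]
        # The post-transition state is what gets recorded.
        samples.append(state)
    return samples


init = [1, 4, 2, 3, 0]
samples = run_chain(init, num_steps=3)
print("initial state:", init)
print("first collected sample:", samples[0])
print("number of collected samples:", len(samples))
```

With this collection pattern, the first recorded sample is already one step away from the initial value, which matches what the _last_state check and the warmup output show in the thread.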