mcmc.get_samples() returns an empty dict

I got infer_discrete to work as well, but it took a bit of a journey. I was stumped because the examples and discussions of infer_discrete always involve a model with both continuous and discrete latent variables, so the procedure was always (source):

  1. Given a model p(data | discrete, continuous), marginalize out the discrete variables and run MCMC to get p(continuous | data).
  2. Use infer_discrete to draw samples from p(discrete | data, continuous).

Since the mystery model has only discrete variables, there is nothing left for MCMC to sample once they are marginalized out, so mcmc.get_samples() returns an empty dictionary. It wasn’t obvious to me what to do next, but I pushed through and realized you don’t even have to run MCMC: you can pass everything to infer_discrete with an empty dictionary as the samples. This seems a little hacky…

Below is the minimal code to run infer_discrete, using the infer_discrete_model snippet from the NumPyro example Bayesian Models of Annotation:

```python
import jax
import jax.numpy as jnp
import numpyro
from numpyro.contrib.funsor import config_enumerate, infer_discrete
import numpyro.distributions as dist


def model(guess, weapon):
    # P(weapon | murderer): row 0 for murderer = 0, row 1 for murderer = 1
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    # Categorical observations are integer indices, so weapon should be an int.
    numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]), obs=weapon)


guess = 0.7
data = (guess, 0)  # model arguments: P(murderer = 1) and the observed weapon index


def infer_discrete_model(rng_key, samples):
    # `samples` normally holds one posterior draw of the continuous sites;
    # this model has none, so it is an empty dict and `condition` is a no-op.
    conditioned_model = numpyro.handlers.condition(model, data=samples)
    infer_discrete_model = infer_discrete(
        config_enumerate(conditioned_model), rng_key=rng_key
    )
    with numpyro.handlers.trace() as tr:
        infer_discrete_model(*data)

    # Keep only the enumerated (discrete latent) sites.
    return {
        name: site["value"]
        for name, site in tr.items()
        if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
    }


num_samples = 4000

discrete_samples = jax.vmap(infer_discrete_model)(
    jax.random.split(jax.random.PRNGKey(1), num_samples), {}
)

discrete_samples["murderer"].mean(), discrete_samples["murderer"].std()
```

Output:

```
(DeviceArray(0.353, dtype=float32), DeviceArray(0.47790274, dtype=float32))
```
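As a sanity check, this posterior is small enough to compute exactly with Bayes’ rule over the two values of murderer: P(murderer = 1 | weapon = 0) = 0.7 × 0.2 / (0.7 × 0.2 + 0.3 × 0.9) = 0.14 / 0.41 ≈ 0.341. A minimal jax.numpy sketch, reusing the numbers from the model above:

```python
import jax.numpy as jnp

guess = 0.7
weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])  # rows: murderer = 0, 1
weapon = 0  # observed weapon index

prior = jnp.array([1.0 - guess, guess])  # P(murderer = 0), P(murderer = 1)
likelihood = weapon_cpt[:, weapon]       # P(weapon = 0 | murderer)
joint = prior * likelihood               # unnormalized posterior
posterior = joint / joint.sum()          # enumerate and normalize

p1 = posterior[1]                 # exact P(murderer = 1 | weapon = 0)
std = jnp.sqrt(p1 * (1.0 - p1))   # std of a Bernoulli(p1) variable
print(p1, std)
```

The sampled mean 0.353 is about 1.5 standard errors (0.474 / √4000 ≈ 0.0075) away from the exact 0.341, and the sampled std 0.478 is close to the exact Bernoulli std of about 0.474, so the infer_discrete output agrees with the exact answer.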

