mcmc.get_samples() returns an empty dict

Hi all,

I am coding up the example from chapter 1 of the MBML book. I expect the MCMC run to contain samples, and I don't think there is an issue with my model definition (maybe?), since I can sample the model directly and obtain the expected conditioning as well as the expected answer.

Am I making an obvious mistake?

# Minimal example of the mystery model
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS

key = jax.random.PRNGKey(2)

guess = 0.7


def mystery(guess):
    # Rows of the CPT give P(weapon | murderer): row 0 for murderer=0, row 1 for murderer=1
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]))
    return murderer, weapon


conditioned_model = numpyro.handlers.condition(mystery, {"weapon": 0.0})  # observe weapon == 0

nuts_kernel = NUTS(conditioned_model)

mcmc = MCMC(nuts_kernel, num_warmup=200, num_samples=200, num_chains=4)
mcmc.run(key, guess)

# mcmc.print_summary()
print(f"\n{mcmc.get_samples()=}")

with numpyro.handlers.seed(rng_seed=0):
    samples = []
    for _ in range(1000):
        # weapon comes back as the plain conditioned value (0.0), hence the hasattr check
        samples.append(
            tuple(
                s.item() if hasattr(s, "item") else s
                for s in conditioned_model(guess)
            )
        )

samples = pd.DataFrame(samples, columns=["murderer", "weapon"])

print(pd.crosstab(samples.murderer, samples.weapon, normalize="all"))

Output:

sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 400/400 [00:01<00:00, 225.78it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 400/400 [00:00<00:00, 7121.75it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 400/400 [00:00<00:00, 6858.40it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 400/400 [00:00<00:00, 7312.40it/s, 1 steps of size 1.19e+37. acc. prob=1.00]

mcmc.get_samples()={}
weapon    0.0
murderer     
0         0.327
1         0.673

Relevant environment information:

python --version = Python 3.9.5
numpyro                   0.6.0                    pypi_0    pypi
funsor                    0.4.1                    pypi_0    pypi
jax                       0.2.16                   pypi_0    pypi
jaxlib                    0.1.68                   pypi_0    pypi

Edit:
I noticed I had made a mistake, using the unconditioned model mystery instead of conditioned_model within NUTS, but running with either still results in an empty dictionary from mcmc.get_samples().

I got the MCMC to run with a fully populated get_samples() using DiscreteHMCGibbs, but I am unsure why NUTS alone won't work. After reviewing the examples under Discrete Latent Variables in the docs, I am guessing it has to do with enumeration of discrete latent variables. Somehow I had gotten the impression, and assumed incorrectly, that enumeration was only for SVI when reviewing Inference with Discrete Latent Variables… =P

Anyway, below is the code snippet using DiscreteHMCGibbs, but I think you could also solve this with config_enumerate.

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

key = jax.random.PRNGKey(2)

guess = 0.7


def mystery(guess):
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]))
    return murderer, weapon


conditioned_model = numpyro.handlers.condition(mystery, {"weapon": 0.0})

nuts_kernel = NUTS(conditioned_model)

kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)  # Gibbs updates for the discrete sites

mcmc = MCMC(kernel, num_warmup=200, num_samples=200, num_chains=4)
mcmc.run(key, guess)

mcmc.print_summary()
# print(f"\n{mcmc.get_samples()=}")

with numpyro.handlers.seed(rng_seed=0):
    samples = []
    for _ in range(1000):
        samples.append(
            tuple(
                s.item() if hasattr(s, "item") else s
                for s in conditioned_model(guess)
            )
        )

samples = pd.DataFrame(samples, columns=["murderer", "weapon"])

print(pd.crosstab(samples.murderer, samples.weapon, normalize="all"))

Output:

sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 400/400 [00:03<00:00, 127.08it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 400/400 [00:00<00:00, 5765.11it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 400/400 [00:00<00:00, 6251.29it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 400/400 [00:00<00:00, 6449.45it/s, 1 steps of size 1.19e+37. acc. prob=1.00]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  murderer      0.34      0.48      0.00      0.00      1.00   1949.83      1.00

weapon      0.0
murderer       
0         0.327
1         0.673

You are right. Using infer_discrete (which is recommended) or DiscreteHMCGibbs is a solution here. The NUTS sampler marginalizes the discrete latent variables out, so there is no latent variable left in your model for it to sample.
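
Concretely, after marginalization all that is left is the constant log p(weapon=0), so NUTS has nothing to sample. For a model this small you can also check the samplers against the closed-form posterior; a quick sketch, using the numbers from your model:

import jax.numpy as jnp

guess = 0.7
weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])

prior = jnp.array([1 - guess, guess])      # P(murderer)
likelihood = weapon_cpt[:, 0]              # P(weapon=0 | murderer)
marginal = jnp.sum(prior * likelihood)     # P(weapon=0) = 0.41, the only term NUTS sees
posterior = prior * likelihood / marginal  # P(murderer | weapon=0)

print(posterior)  # ~[0.6585, 0.3415], matching the DiscreteHMCGibbs mean of 0.34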

I got infer_discrete to work as well, but it was a bit of a journey. I was stumped, since the examples and discussion for infer_discrete were always about a model with both continuous and discrete variables, so the procedure was always (source):

  1. Have a model p(data | discrete, continuous), marginalize the discrete sites, and run MCMC to get p(continuous | data)
  2. Use infer_discrete to get samples from p(discrete | data, continuous) — see the sketch after this list
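
For reference, here is a minimal sketch of that two-step recipe on a hypothetical two-component Gaussian mixture (the model, site names, and numbers are my own illustration, not from this thread):

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.infer import MCMC, NUTS

data = jnp.array([-2.1, -1.9, 2.0, 2.2])

@config_enumerate
def mixture(data):
    # Continuous site: the two component locations
    locs = numpyro.sample("locs", dist.Normal(jnp.zeros(2), 5.0).to_event(1))
    with numpyro.plate("data", len(data)):
        # Discrete site: the per-point component assignments
        z = numpyro.sample("z", dist.Categorical(jnp.array([0.5, 0.5])))
        numpyro.sample("x", dist.Normal(locs[z], 0.5), obs=data)

# Step 1: NUTS enumerates z out and samples p(locs | data)
mcmc = MCMC(NUTS(mixture), num_warmup=500, num_samples=500)
mcmc.run(jax.random.PRNGKey(0), data)
locs_samples = mcmc.get_samples()["locs"]

# Step 2: recover p(z | data, locs) with infer_discrete, one draw per posterior sample
def sample_z(rng_key, locs):
    conditioned = numpyro.handlers.condition(mixture, data={"locs": locs})
    with numpyro.handlers.trace() as tr:
        infer_discrete(conditioned, rng_key=rng_key)(data)
    return tr["z"]["value"]

rng_keys = jax.random.split(jax.random.PRNGKey(1), len(locs_samples))
z_samples = jax.vmap(sample_z)(rng_keys, locs_samples)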

Since the mystery model has only discrete variables, mcmc.get_samples() is an empty dictionary, and it didn't make a lot of sense to me what to do. I pushed through anyway and realized you don't even have to run MCMC: you can pass infer_discrete an empty dictionary in place of the posterior samples. This seems a little hacky…

Below is the minimum amount of code to run infer_discrete, using the infer_discrete_model snippet from Example: Bayesian Models of Annotation:

import jax
import jax.numpy as jnp
import numpyro
from numpyro.contrib.funsor import config_enumerate, infer_discrete
import numpyro.distributions as dist

def infer_discrete_model(rng_key, samples):
    # Condition the model on the (continuous) posterior samples -- here there are none
    conditioned_model = numpyro.handlers.condition(model, data=samples)
    # Enumerate the discrete sites and sample them from their conditionals
    enum_model = infer_discrete(
        config_enumerate(conditioned_model), rng_key=rng_key
    )
    with numpyro.handlers.trace() as tr:
        enum_model(*data)

    # Keep only the enumerated sample sites
    return {
        name: site["value"]
        for name, site in tr.items()
        if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
    }

guess = 0.7

def model(guess, weapon):
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    numpyro.sample(
        "weapon", dist.Categorical(weapon_cpt[murderer]), obs=weapon
    )
data = (guess, 0.)

num_samples = 4000

# There are no continuous sites, so there are no MCMC samples to condition on:
# pass an empty dict and let infer_discrete handle the discrete sites
discrete_samples = jax.vmap(infer_discrete_model)(
    jax.random.split(jax.random.PRNGKey(1), num_samples), {}
)

discrete_samples["murderer"].mean(), discrete_samples["murderer"].std()

Output:

(DeviceArray(0.353, dtype=float32), DeviceArray(0.47790274, dtype=float32))


Hi @bdatko, you are right: you don't need to run MCMC if there is no latent variable in your model (assuming the discrete sites are marginalized out). We just updated the annotation example to simplify the process. In your case, you can just do

predictive = Predictive(conditioned_model, num_samples=1000, infer_discrete=True)
samples = predictive(key, guess)

Ah, yes that’s a better interface.

How recent was that addition? I am pip installing from GitHub and just now needed to rebuild to see the change in Predictive.

Edit:
Apparently it was 2 hours ago.

Thanks @fehiepsi!
Full working snippet:

import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import Predictive
import numpyro.distributions as dist

key = jax.random.PRNGKey(3)

guess = 0.7

def mystery(guess):
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]))
    return murderer, weapon

conditioned_model = numpyro.handlers.condition(mystery, {"weapon": 0.0})

predictive = Predictive(conditioned_model, num_samples=1000, infer_discrete=True)
samples = predictive(key, guess)
samples["murderer"].mean(), samples["murderer"].std()

Output:

(DeviceArray(0.356, dtype=float32), DeviceArray(0.47881523, dtype=float32))