Hi all,
I am coding the example from the MBML book, chapter 1. I am expecting to have samples within my mcmc, and I donβt think there is an issue with my model definition (maybe?) since I can just sample the model and obtain the correct conditioning as well as the correct answer.
Am I making an obvious mistake?
# Min example of a mystery
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS
key = jax.random.PRNGKey(2)
guess = 0.7
def mystery(guess):
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]))
    return murderer, weapon
conditioned_model = numpyro.handlers.condition(mystery, {"weapon": 0.0})
nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(nuts_kernel, num_warmup=200, num_samples=200, num_chains=4)
mcmc.run(key, guess)
# mcmc.print_summary()
print(f"\n{mcmc.get_samples()=}")
with numpyro.handlers.seed(rng_seed=0):
    samples = []
    for _ in range(1000):
        samples.append(
            tuple(
                [
                    sample.item() if hasattr(sample, "item") else sample
                    for sample in conditioned_model(guess)
                ]
            )
        )
samples = pd.DataFrame(samples, columns=["murderer", "weapon"])
print(pd.crosstab(samples.murderer, samples.weapon, normalize="all"))
Output:
sample: 100%|βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 400/400 [00:01<00:00, 225.78it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 400/400 [00:00<00:00, 7121.75it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 400/400 [00:00<00:00, 6858.40it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 400/400 [00:00<00:00, 7312.40it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
mcmc.get_samples()={}
weapon    0.0
murderer     
0         0.327
1         0.673
relevant environment information:
python --version = Python 3.9.5
numpyro                   0.6.0                    pypi_0    pypi
funsor                    0.4.1                    pypi_0    pypi
jax                       0.2.16                   pypi_0    pypi
jaxlib                    0.1.68                   pypi_0    pypi
Edit
I noticed I had a mistake with using the un-conditioned model mystery instead of the conditioned_model within NUTS, but running with either still results in an empty dictionary from mcmc.get_samples()