Hi all,
I am coding the example from the MBML book, chapter 1. I am expecting to have samples within my mcmc
, and I donβt think there is an issue with my model definition (maybe?) since I can just sample the model and obtain the correct conditioning as well as the correct answer.
Am I making an obvious mistake?
# Min example of a mystery
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS
key = jax.random.PRNGKey(2)
guess = 0.7
def mystery(guess):
weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]))
return murderer, weapon
conditioned_model = numpyro.handlers.condition(mystery, {"weapon": 0.0})
nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(nuts_kernel, num_warmup=200, num_samples=200, num_chains=4)
mcmc.run(key, guess)
# mcmc.print_summary()
print(f"\n{mcmc.get_samples()=}")
with numpyro.handlers.seed(rng_seed=0):
samples = []
for _ in range(1000):
samples.append(
tuple(
[
sample.item() if hasattr(sample, "item") else sample
for sample in conditioned_model(guess)
]
)
)
samples = pd.DataFrame(samples, columns=["murderer", "weapon"])
print(pd.crosstab(samples.murderer, samples.weapon, normalize="all"))
Output:
sample: 100%|βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 400/400 [00:01<00:00, 225.78it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 400/400 [00:00<00:00, 7121.75it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 400/400 [00:00<00:00, 6858.40it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
sample: 100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 400/400 [00:00<00:00, 7312.40it/s, 1 steps of size 1.19e+37. acc. prob=1.00]
mcmc.get_samples()={}
weapon 0.0
murderer
0 0.327
1 0.673
relevant environment information:
python --version = Python 3.9.5
numpyro 0.6.0 pypi_0 pypi
funsor 0.4.1 pypi_0 pypi
jax 0.2.16 pypi_0 pypi
jaxlib 0.1.68 pypi_0 pypi
Edit
I noticed I had a mistake with using the un-conditioned model mystery
instead of the conditioned_model
within NUTS
, but running with either still results in an empty dictionary from mcmc.get_samples()