I got infer_discrete to work as well, but it was a bit of a journey. I was stumped because the examples and discussions of infer_discrete were always about a model with both continuous and discrete variables, so the procedure was always (source):
- Have a model p(data | discrete, continuous), marginalize out the discrete variables, and run MCMC to get p(continuous | data)
- Use infer_discrete to get samples from p(discrete | data, continuous)
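For reference, the two steps can be written out as a rough sketch. The symbols are my own shorthand, not notation from the linked source: d stands for the discrete latents, c for the continuous latents, and x for the data.

```latex
\begin{align}
p(c \mid x) &\propto p(c) \sum_{d} p(d \mid c)\, p(x \mid d, c) \\
p(d \mid x, c) &= \frac{p(d \mid c)\, p(x \mid d, c)}{\sum_{d'} p(d' \mid c)\, p(x \mid d', c)}
\end{align}
```

The first line is the target that MCMC samples once the discrete variables are summed out; the second is the distribution infer_discrete draws from for each continuous sample.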
Since the mystery model has only discrete variables, mcmc.get_samples() is an empty dictionary. It wasn't clear to me what to do next, but I pushed through and realized you don't even have to run MCMC: you can pass everything to infer_discrete with just an empty dictionary. This seems a little hacky…
Below is the minimum amount of code needed to run infer_discrete, using the infer_discrete_model snippet from Example: Bayesian Models of Annotation:
    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.funsor import config_enumerate, infer_discrete


    def infer_discrete_model(rng_key, samples):
        # Condition on the continuous samples (here an empty dict), then
        # draw the enumerated discrete sites with infer_discrete.
        conditioned_model = numpyro.handlers.condition(model, data=samples)
        infer_discrete_model = infer_discrete(
            config_enumerate(conditioned_model), rng_key=rng_key
        )
        with numpyro.handlers.trace() as tr:
            infer_discrete_model(*data)

        # Keep only the discrete sites that were enumerated.
        return {
            name: site["value"]
            for name, site in tr.items()
            if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
        }


    guess = 0.7


    def model(guess, weapon):
        # weapon_cpt[murderer] is the distribution over weapons for that murderer
        weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
        murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
        weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]), obs=weapon)


    # The observed weapon is category 0 (an integer index, since it is a Categorical site)
    data = (guess, 0)

    num_samples = 4000
    # No continuous samples to condition on, so pass an empty dict
    discrete_samples = jax.vmap(infer_discrete_model)(
        jax.random.split(jax.random.PRNGKey(1), num_samples), {}
    )

    discrete_samples["murderer"].mean(), discrete_samples["murderer"].std()
Output:
(DeviceArray(0.353, dtype=float32), DeviceArray(0.47790274, dtype=float32))
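As a sanity check on the numbers above: with no continuous variables, infer_discrete is sampling directly from p(murderer | weapon), so the sample mean should land near the exact posterior from Bayes' rule. Here is a minimal sketch in plain Python, assuming the same prior (guess = 0.7), the same conditional probability table, and weapon observed as category 0. The variable names are mine, for illustration only.

```python
# Exact posterior for the two-variable model, computed by Bayes' rule.
# Assumes guess = 0.7 and the observed weapon is category 0, as in the post.
prior_murderer = [0.3, 0.7]              # P(murderer = 0), P(murderer = 1)
weapon_cpt = [[0.9, 0.1], [0.2, 0.8]]    # weapon_cpt[murderer][weapon]
observed_weapon = 0

# Joint P(murderer = m, weapon = observed) for each value of m
joint = [prior_murderer[m] * weapon_cpt[m][observed_weapon] for m in (0, 1)]
evidence = sum(joint)                    # P(weapon = observed)
posterior = [j / evidence for j in joint]

# For a Bernoulli variable the mean is P(murderer = 1) and the std is sqrt(p * (1 - p))
posterior_mean = posterior[1]
posterior_std = (posterior[1] * posterior[0]) ** 0.5

print(posterior_mean)  # approximately 0.3415
print(posterior_std)   # approximately 0.4742
```

The sampled mean of 0.353 is within about 1.5 standard errors of 0.3415 for 4000 draws, so the empty-dictionary approach looks consistent with the exact answer.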
Some references that helped me get something to work: