Obtaining discrete variable samples from trace

Given the model:

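import numpyro as npy
from numpyro.distributions import Bernoulli, Normal

def model():
    # Discrete latent site: selects a mean of 0 or 1 for the observations
    which = npy.sample("which", Bernoulli(0.5))
    with npy.plate("observations", len(data)):
        npy.sample("obs", Normal(which), obs=data)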

Evaluated with:

from jax import random
from numpyro.infer import MCMC, NUTS, init_to_median

nuts_kernel = NUTS(
    model,
    init_strategy=init_to_median,
    find_heuristic_step_size=True,
)
mcmc = MCMC(nuts_kernel, num_warmup=50, num_samples=25)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key)

When I retrieve the samples with samples = mcmc.get_samples(), I get an empty dict: there are no values for the discrete site which.

Is there a way to obtain these samples?

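Hi @GUIpsp, we are working on this issue. In the meantime, you can use DiscreteHMCGibbs, which wraps NUTS and updates the discrete latent sites explicitly, so they are returned by mcmc.get_samples().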
