Given the model:
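```python
import numpyro as npy
from numpyro.distributions import Bernoulli, Normal

def model():
    # Discrete latent variable taking the value 0 or 1.
    which = npy.sample("which", Bernoulli(0.5))
    with npy.plate("observations", len(data)):  # `data`: observed array defined elsewhere
        npy.sample("obs", Normal(which), obs=data)
```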
and running inference with:
```python
from jax import random
from numpyro.infer import MCMC, NUTS, init_to_median

nuts_kernel = NUTS(
    model,
    init_strategy=init_to_median,
    find_heuristic_step_size=True,
)
mcmc = MCMC(nuts_kernel, num_warmup=50, num_samples=25)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key)
```
When I retrieve the samples with `samples = mcmc.get_samples()`, I get an empty dict.
Is there a way to obtain samples of `which`?
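A minimal sketch of one workaround, under stated assumptions. My understanding is that NUTS cannot move over a discrete site, so NumPyro enumerates (marginalizes) `which` out instead of sampling it, which would leave `get_samples()` with no entries here. Because `which` is the model's only latent variable and has just two states, its posterior can instead be computed exactly by enumerating both states by hand and then drawing from it. This is a swapped-in technique (exact enumeration in plain JAX), not something `MCMC` returns; the `data` values and the helper `posterior_which` are hypothetical.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical observations; replace with the real `data` array.
data = jnp.array([0.9, 1.1, 1.3, 0.8])


def posterior_which(data):
    """Exact p(which = k | data) for k in {0, 1} (hypothetical helper)."""
    # log p(which = k) = log 0.5 for both states of Bernoulli(0.5).
    log_prior = jnp.log(0.5)
    # log p(data | which = k) = sum_i log Normal(data_i; loc=k, scale=1).
    log_lik = jnp.stack(
        [norm.logpdf(data, loc=k, scale=1.0).sum() for k in (0.0, 1.0)]
    )
    log_joint = log_prior + log_lik
    # Normalize over the two states to get posterior probabilities.
    return jnp.exp(log_joint - logsumexp(log_joint))


probs = posterior_which(data)
# 25 draws of `which` (0 or 1), matching num_samples=25 in the MCMC call.
samples = {
    "which": random.categorical(random.PRNGKey(0), jnp.log(probs), shape=(25,))
}
```

With unit-scale Normals, the log-likelihood ratio between the two states reduces to `sum(data) - len(data) / 2`, so for these placeholder values `probs[1]` equals `sigmoid(2.1)`; that gives a quick sanity check on the enumeration.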