Recover discrete latent states after enumerate, scan

You can either use infer_discrete directly or pass infer_discrete=True to Predictive, i.e. predictive = Predictive(model, posterior_samples, infer_discrete=True), as in the following tutorial: Example: Bayesian Models of Annotation — NumPyro documentation