Hello! Could you please help clarify the following?
We worked through the "Example: Enumerate Hidden Markov Model" page in the NumPyro documentation, and we have the following questions:
How is inference done for the discrete random variable x in model_1? It is not clear to us how x is handled during inference.
For comparison, we also looked at the Pyro version of this example, but it is still not clear how enumeration is done in NumPyro. The NumPyro code uses plates with explicit dims and mask=True, but nothing that explicitly marks a site for enumeration. Could you please clarify how inference is done for the discrete variables?
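For reference, here is a minimal plain-Python sketch (not NumPyro code, and not a claim about NumPyro's internals) of what "summing out" a discrete chain like x means. It compares brute-force enumeration over every hidden path with eliminating x one time step at a time (the forward algorithm); both should give the same log marginal likelihood log p(y). The toy HMM parameters and the function names are made up for illustration.

```python
import itertools
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def brute_force_log_marginal(init, trans, emit, ys):
    """Sum p(x, y) over every hidden path x: K**T terms."""
    K, T = len(init), len(ys)
    terms = []
    for path in itertools.product(range(K), repeat=T):
        lp = math.log(init[path[0]]) + math.log(emit[path[0]][ys[0]])
        for t in range(1, T):
            lp += math.log(trans[path[t - 1]][path[t]])
            lp += math.log(emit[path[t]][ys[t]])
        terms.append(lp)
    return logsumexp(terms)


def forward_log_marginal(init, trans, emit, ys):
    """Eliminate x one step at a time: roughly T * K**2 work."""
    K = len(init)
    # alpha[k] = log p(y_1..y_t, x_t = k)
    alpha = [math.log(init[k]) + math.log(emit[k][ys[0]]) for k in range(K)]
    for y in ys[1:]:
        alpha = [
            logsumexp([alpha[j] + math.log(trans[j][k]) for j in range(K)])
            + math.log(emit[k][y])
            for k in range(K)
        ]
    return logsumexp(alpha)


# Made-up toy HMM: 2 hidden states, binary observations.
init = [0.6, 0.4]
trans = [[0.7, 0.3], [0.2, 0.8]]
emit = [[0.9, 0.1], [0.3, 0.7]]
ys = [0, 1, 1, 0, 1]

print(brute_force_log_marginal(init, trans, emit, ys))
print(forward_log_marginal(init, trans, emit, ys))
```

The two printed values should agree up to floating-point error; the difference is only in how much work each one does.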
model_1 runs very slowly. We reduced the data to a very small size, but it is still slow. Does the cost of inference depend on the number of discrete variables, even when the model is written with scan?
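As a rough illustration of why the number of discrete variables matters, the sketch below counts terms under the assumption that each x_t depends only on x_{t-1} (a chain, as in an HMM). Naive enumeration over all joint assignments grows as K**T, while eliminating x step by step grows as about T * K**2. The function names and the choice K = 16 are hypothetical and only illustrative; the counts ignore constant factors, the number of sequences, and any library overhead.

```python
def brute_force_terms(num_states, num_steps):
    """Joint configurations of x_1..x_T that naive enumeration sums over."""
    return num_states ** num_steps


def chain_elimination_terms(num_states, num_steps):
    """Products formed when x is eliminated one step at a time:
    K for the first step, then a K x K table for each of the T - 1 transitions."""
    return num_states + (num_steps - 1) * num_states ** 2


for K, T in [(2, 5), (16, 10), (16, 50)]:
    print(K, T, brute_force_terms(K, T), chain_elimination_terms(K, T))
```

Both counts grow with K, so the number of hidden states matters either way; the difference is whether T sits in the exponent. Note also that NUTS evaluates the log density and its gradient once per leapfrog step, so whatever the per-evaluation cost is, it is multiplied by the number of leapfrog steps per sample.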
Could you please point us to the NumPyro source code and documentation where inference for discrete variables is done? We could only find documentation about the distributions, not about the inference.
How can we obtain samples of x from model_1?
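For context, one standard way to draw hidden states x given fixed HMM parameters is forward-filtering backward-sampling (FFBS). The plain-Python sketch below is not NumPyro's implementation; the toy parameters are made-up stand-ins for a single posterior draw of the transition and emission probabilities, and the function names are hypothetical. It includes a brute-force posterior for the first state purely as a sanity check against the sampled frequencies.

```python
import itertools
import random


def normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]


def ffbs(init, trans, emit, ys, rng):
    """Draw one hidden path x ~ p(x | y) by forward filtering, backward sampling."""
    K, T = len(init), len(ys)
    # Forward filter: filt[t][k] = p(x_t = k | y_1..y_t)
    filt = [normalize([init[k] * emit[k][ys[0]] for k in range(K)])]
    for y in ys[1:]:
        prev = filt[-1]
        filt.append(normalize([
            sum(prev[j] * trans[j][k] for j in range(K)) * emit[k][y]
            for k in range(K)
        ]))
    # Backward sample: x_T from the last filter, then x_t given x_{t+1} and y_1..y_t
    path = [0] * T
    path[-1] = rng.choices(range(K), weights=filt[-1])[0]
    for t in range(T - 2, -1, -1):
        w = [filt[t][j] * trans[j][path[t + 1]] for j in range(K)]
        path[t] = rng.choices(range(K), weights=w)[0]
    return path


def exact_first_state_posterior(init, trans, emit, ys):
    """p(x_1 = k | y) by brute-force enumeration, for checking only."""
    K, T = len(init), len(ys)
    post = [0.0] * K
    for path in itertools.product(range(K), repeat=T):
        p = init[path[0]] * emit[path[0]][ys[0]]
        for t in range(1, T):
            p *= trans[path[t - 1]][path[t]] * emit[path[t]][ys[t]]
        post[path[0]] += p
    return normalize(post)


# Made-up toy HMM: 2 hidden states, binary observations.
init = [0.6, 0.4]
trans = [[0.7, 0.3], [0.2, 0.8]]
emit = [[0.9, 0.1], [0.3, 0.7]]
ys = [0, 1, 1, 0, 1]

rng = random.Random(0)
draws = [ffbs(init, trans, emit, ys, rng) for _ in range(20000)]
freq_first_state_0 = sum(d[0] == 0 for d in draws) / len(draws)
print(freq_first_state_0, exact_first_state_posterior(init, trans, emit, ys)[0])
```

In a full Bayesian setting one would repeat this once per posterior draw of the continuous parameters, so that each x sample is paired with its own parameter sample.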