Inference in HMM example

Hello! Could you please help clarify the following?

We worked through "Example: Enumerate Hidden Markov Model" in the NumPyro documentation and have the following questions:

  1. How is inference done for the discrete random variable x in model_1?
    For comparison, we also looked at the Pyro code, but it is not clear how the enumeration is done in NumPyro: plate with dim and mask=True are used, but nothing in model_1 indicates enumeration. Could you please clarify how inference is done for discrete variables?

  2. model_1 runs very slowly. Even after reducing the data to a very small size, it still seems slow. Does the cost of inference depend on the number of discrete variables, even when using scan?

  3. Could you please point us to the NumPyro source code and documentation where inference for discrete variables is done? We could only find documentation about the distributions, not about the inference.

  4. How can we get the x samples in model_1?

Thank you!

In that example, we marginalize out the discrete latent variable automatically. To avoid confusion, I think future versions should remove that "automatic" behavior and require users to provide infer={"enumerate": "parallel"} explicitly. Please raise a GitHub issue for this.
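To make "marginalize out the discrete latent variable" concrete: for an HMM, summing over every possible hidden path is the forward algorithm. The sketch below is plain JAX, not NumPyro's internal code; `hmm_log_marginal` and its argument layout are hypothetical names chosen for illustration. In NumPyro you would instead mark the site with infer={"enumerate": "parallel"} and let the library perform an equivalent sum.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hmm_log_marginal(log_init, log_trans, log_emit):
    """log p(y_{0:T-1}) for an HMM with the discrete state x summed out.

    log_init:  (K,)    log p(x_0 = k)
    log_trans: (K, K)  log p(x_t = j | x_{t-1} = i), rows indexed by i
    log_emit:  (T, K)  log p(y_t | x_t = k) evaluated at the observed y_t
    """
    # alpha_0[k] = log p(x_0 = k, y_0)
    alpha0 = log_init + log_emit[0]

    def step(alpha, log_e):
        # Sum (in log space) over the previous state i for every new state j
        alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + log_e
        return alpha, None

    alpha_last, _ = jax.lax.scan(step, alpha0, log_emit[1:])
    # Finally sum over the last hidden state
    return logsumexp(alpha_last)
```

Each step costs K x K operations, and no per-path samples of x are kept, which is why x does not appear in the MCMC output.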

For the source code, I think scan is what you need. Its documentation briefly explains how parallel-scan is used.
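The parallel-scan idea can be sketched in plain JAX: each time step becomes a K x K matrix in log space, and since log-space matrix multiplication is associative, `jax.lax.associative_scan` can combine all steps in logarithmic depth instead of one after another. This illustrates the technique under our own naming (`log_matmul` and `hmm_log_marginal_parallel` are hypothetical); it is not a copy of NumPyro's scan implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_matmul(a, b):
    """Matrix product in the (logsumexp, +) semiring, batched over leading dims."""
    # (..., i, k, 1) + (..., 1, k, j), then sum out k
    return logsumexp(a[..., :, :, None] + b[..., None, :, :], axis=-2)


def hmm_log_marginal_parallel(log_init, log_trans, log_emit):
    """log p(y) of an HMM with x summed out, computed with a parallel prefix scan.

    log_init (K,), log_trans (K, K) with rows = previous state,
    log_emit (T, K) at the observed y_t; assumes T >= 2.
    """
    # One transfer matrix per step t >= 1:
    # M_t[i, j] = log p(x_t = j | x_{t-1} = i) + log p(y_t | x_t = j)
    mats = log_trans[None, :, :] + log_emit[1:, None, :]
    # Combine the T - 1 matrices in O(log T) depth; the last prefix is the full product
    total = jax.lax.associative_scan(log_matmul, mats)[-1]
    alpha0 = log_init + log_emit[0]
    alpha_last = logsumexp(alpha0[:, None] + total, axis=0)
    return logsumexp(alpha_last)
```

The price of the shorter depth is that every combine is a full K x K x K log-space product rather than the K x K vector-matrix step of a sequential pass.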

About the speed: I don't think the number of discrete variables matters much for those examples (there is only one discrete latent variable, and we marginalize it out). Maybe the posterior is tricky to sample from, or the number of computations is large (e.g. computing 100 x 100 x 100 log operations can be slow on CPU).
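A rough operation count shows how a figure like 100 x 100 x 100 can arise. This is our own back-of-the-envelope estimate, assuming K hidden states, T time steps, and that each parallel-scan combine is a K x K log-space matrix product:

```latex
\begin{aligned}
\text{sequential forward pass:}\quad & T \cdot K^2\ \text{log-sum-exp terms, depth } T \\
\text{parallel scan:}\quad & \mathcal{O}(T)\ \text{combines} \times K^3\ \text{terms each} = \mathcal{O}(T K^3),\ \text{depth } \mathcal{O}(\log T) \\
K = 100:\quad & K^3 = 100 \times 100 \times 100 = 10^6\ \text{terms per combine}
\end{aligned}
```

On CPU, where the extra parallelism helps little, the additional factor of K in total work can outweigh the shorter depth.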

There are two issues with the current implementation:

  • Although we use the parallel-scan algorithm to compute the log density, we still need jax.lax.scan to collect the distribution parameters, which hurts performance on CPU. We opened this issue to make that sequential (collect) step fully parallel as well (it is doable, but we haven't had much time to look into it yet).
  • For inferring x, there is a pending PR. The issues around it are a bit tricky to address…
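Until then, one standard way to draw x from p(x | y), given one posterior draw of the parameters (e.g. probs_x and probs_y), is forward filtering, backward sampling (FFBS). The sketch below is plain JAX with a hypothetical function name; it is not necessarily the approach taken in the pending PR. You would run it once per posterior sample of the parameters.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def ffbs_sample(key, log_init, log_trans, log_emit):
    """Draw one hidden path x_{0:T-1} ~ p(x | y) by forward filtering, backward sampling.

    log_init (K,), log_trans (K, K) with rows = previous state,
    log_emit (T, K) = log p(y_t | x_t = k) at the observed y_t.
    """
    # Forward pass: keep every alpha_t = log p(x_t, y_{0:t})
    def forward(alpha, log_e):
        alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + log_e
        return alpha, alpha

    alpha0 = log_init + log_emit[0]
    _, alphas_rest = jax.lax.scan(forward, alpha0, log_emit[1:])
    alphas = jnp.concatenate([alpha0[None], alphas_rest], axis=0)  # (T, K)

    T = log_emit.shape[0]
    keys = jax.random.split(key, T)
    # Last state: p(x_{T-1} | y) is proportional to exp(alpha_{T-1})
    x_last = jax.random.categorical(keys[-1], alphas[-1])

    # Backward pass: p(x_t | x_{t+1}, y) is proportional to exp(alpha_t) * p(x_{t+1} | x_t)
    def backward(x_next, inputs):
        alpha_t, k = inputs
        x_t = jax.random.categorical(k, alpha_t + log_trans[:, x_next])
        return x_t, x_t

    _, xs = jax.lax.scan(backward, x_last, (alphas[:-1], keys[:-1]), reverse=True)
    return jnp.concatenate([xs, x_last[None]])
```

With a batch of sequences, `jax.vmap` over the keys and the per-sequence emissions would give one path per sequence.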

Thank you @fehiepsi! That is very helpful!
OK, the request is now on GitHub.

In the gaussian_hmm example from the scan docs, there is no mask(mask=True). Why do we need it in model_1?

Also, what does mask(mask=(t < lengths)[..., None]) do? It is not in gaussian_hmm either. Why do we need it in model_1 but not in gaussian_hmm?

Each enumerated variable has its own enumeration dimension. Is it possible to annotate a loop with something similar to pyro.markov() in NumPyro?

We used mask there because those time series have different lengths. The gaussian_hmm example has only a single time series, so it doesn't need the mask.
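The point of masking padded steps can be illustrated without NumPyro: when sequences are padded to a common length, steps at or beyond a sequence's true length must not change its log density. The sketch below (hypothetical function name, one sequence; `jax.vmap` over sequences and lengths handles a batch) simply leaves the forward message unchanged on padded steps. It is an analogy for what masking achieves, not how NumPyro's mask handler is implemented.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hmm_log_marginal_masked(log_init, log_trans, log_emit, length):
    """Forward algorithm over a padded sequence; steps t >= length are ignored.

    log_init (K,), log_trans (K, K) with rows = previous state,
    log_emit (T_max, K); `length` is the true length, 1 <= length <= T_max.
    """
    alpha0 = log_init + log_emit[0]

    def step(carry, log_e):
        alpha, t = carry
        new_alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + log_e
        # Padded step (t >= length): keep the previous message unchanged
        alpha = jnp.where(t < length, new_alpha, alpha)
        return (alpha, t + 1), None

    # The first scanned element is time step 1
    (alpha_last, _), _ = jax.lax.scan(step, (alpha0, 1), log_emit[1:])
    return logsumexp(alpha_last)
```

Whatever values sit in the padded rows of log_emit, the result for a given length is the same as for the unpadded sequence.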

Is it possible to annotate a loop with something similar to pyro.markov() in NumPyro?

I think you can do it with contrib.funsor.markov, but it would be slow. I would recommend using scan for time-series problems.


Great, thank you @fehiepsi! That is so helpful.