Inference with DiscreteHMM

  1. Does Pyro have a function to get the most likely sequence of hidden states (Viterbi) for a model containing DiscreteHMM?

  2. Is there a way to implement a discrete HMM model in Pyro where multiple random variables depend on the same discrete hidden state?

Thanks

Updates:

I was able to do Viterbi-like MAP inference. I fit the model using the DiscreteHMM distribution and then apply infer_discrete with temperature=0 to an identical model written using the pyro.markov primitive. Is this the intended way of doing discrete inference?

Hi @ordabayev, yes, I like your solution of training with a DiscreteHMM then writing an equivalent model and using infer_discrete for Viterbi inference.

Is there a way to implement a discrete HMM model in Pyro where multiple random variables depend on the same discrete hidden state?

Yes there are a few options:

  1. You can use DiscreteHMM with random variables that can be batched into a single random variable. For example, if two downstream Normal variables depend on the latent state, you can stack them into a single distribution:
    obs_dist = Normal(torch.stack([loc1, loc2], dim=-1),
                      torch.stack([scale1, scale2], dim=-1)).to_event(1)
    DiscreteHMM(..., obs_dist)
    
  2. If your distributions cannot be concatenated, you can use pyro.markov and the non-vectorized encoding of random variables in Pyro:
    for t in pyro.markov(range(T)):
        z = pyro.sample("z_{}".format(t), ...,
                        infer={"enumerate": "parallel"})
        # each observed site needs a unique name
        pyro.sample("x1_{}".format(t), dist.Normal(locs[z], 1),
                    obs=x1[t])
        pyro.sample("x2_{}".format(t), dist.Bernoulli(probs[z]),
                    obs=x2[t])
    
  3. If you really want to use DiscreteHMM for heterogeneous observations, you could manually construct a Categorical observation distribution whose logits are the sum of log-likelihoods from other distributions. It will be difficult to sample from this (e.g. for forecasting), but if all you want is MAP latent state (e.g. for segmentation), then this will be fast and general.
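
To make the idea behind option 3 concrete, here is a small sketch in plain Python, deliberately independent of Pyro. It sums the log-likelihoods of a Normal and a Bernoulli observation into one per-state score for each time step, then runs a textbook Viterbi pass over those scores. All names and numbers (viterbi, locs, probs, the toy data) are made up for illustration, and it assumes the initial state emits the first observation. Something like this can also serve as a reference to check a Pyro result against on small problems.

```python
import math


def normal_logpdf(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def bernoulli_logpmf(x, p):
    return math.log(p) if x == 1 else math.log(1.0 - p)


def viterbi(log_init, log_trans, log_lik):
    """MAP state path; log_lik[t][k] is the summed log-likelihood at time t under state k."""
    num_states = len(log_init)
    score = [log_init[k] + log_lik[0][k] for k in range(num_states)]
    backpointers = []
    for t in range(1, len(log_lik)):
        new_score, ptr = [], []
        for k in range(num_states):
            # Best predecessor state for landing in state k at time t.
            best_j = max(range(num_states), key=lambda j: score[j] + log_trans[j][k])
            ptr.append(best_j)
            new_score.append(score[best_j] + log_trans[best_j][k] + log_lik[t][k])
        score = new_score
        backpointers.append(ptr)
    # Trace back from the best final state.
    path = [max(range(num_states), key=lambda k: score[k])]
    for ptr in reversed(backpointers):
        path.append(ptr[path[-1]])
    return path[::-1]


# Toy two-state example with one Normal and one Bernoulli observation per step.
locs, probs = [0.0, 5.0], [0.1, 0.9]
x1 = [0.2, -0.1, 5.1, 4.8]
x2 = [0, 0, 1, 1]
log_lik = [
    [normal_logpdf(a, locs[k], 1.0) + bernoulli_logpmf(b, probs[k]) for k in range(2)]
    for a, b in zip(x1, x2)
]
log_init = [math.log(0.5), math.log(0.5)]
log_trans = [[math.log(0.9), math.log(0.1)], [math.log(0.1), math.log(0.9)]]

path = viterbi(log_init, log_trans, log_lik)
```

The point of the summed log_lik table is that the decoder never needs to know how many observed variables there are or what families they come from, which is why this route works for MAP segmentation but gives no way to sample new observations.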

Good luck, and don’t hesitate to paste successful code snippets here; I’m sure others would benefit!
@fritzo