Infer discrete latent states of a hidden Markov model

I am attempting to recreate the Dishonest Casino example from Kevin P. Murphy’s book “Probabilistic Machine Learning: Advanced Topics”. The authors provide an example notebook and solution in their new library, Dynamax. However, I believe this can be recreated (and could make a nice example notebook) using Pyro’s DiscreteHMM distribution and the infer_discrete function. I am having trouble inferring the latent state for each observation (emission). The parameters for the HMM are:

import torch
import pyro.distributions as dist

initial_probs = torch.tensor([0.5, 0.5])
transition_matrix = torch.tensor([[0.95, 0.05],
                                  [0.10, 0.90]])
emission_probs = torch.tensor([[1/6,  1/6,  1/6,  1/6,  1/6,  1/6],    # fair die
                               [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]])  # loaded die
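Because DiscreteHMM takes logits rather than probabilities, it is worth checking how these probabilities should be converted. A quick sketch in NumPy (standing in for torch; the helper names softmax and log_odds are my own): Categorical-style logits are unnormalised log-probabilities, so log(p) round-trips through a softmax, whereas the log-odds log(p/(1-p)), which is what torch.logit computes, does not.

```python
import numpy as np

# Same parameters as above (state 0 = fair die, state 1 = loaded die)
initial_probs = np.array([0.5, 0.5])
transition_matrix = np.array([[0.95, 0.05],
                              [0.10, 0.90]])
emission_probs = np.array([[1/6] * 6,
                           [1/10] * 5 + [5/10]])

def softmax(logits):
    """Normalise along the last axis, the way a Categorical treats logits."""
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def log_odds(p):
    """Elementwise log(p / (1 - p)), the quantity torch.logit returns."""
    return np.log(p) - np.log1p(-p)

# Each row must be a probability distribution
assert np.allclose(transition_matrix.sum(axis=-1), 1.0)
assert np.allclose(emission_probs.sum(axis=-1), 1.0)

print(softmax(np.log(transition_matrix)))    # recovers transition_matrix
print(softmax(log_odds(transition_matrix)))  # roughly [[0.9972, 0.0028], [0.0122, 0.9878]]
```

For a two-entry row, softmax of the log-odds gives p^2 / (p^2 + (1-p)^2), so 0.95 silently becomes about 0.997.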

Using a DiscreteHMM distribution parameterised with these values, we can generate samples:

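# DiscreteHMM expects logits (unnormalised log-probabilities), so use .log();
# torch.logit computes log-odds, which would distort the transition matrix
hmm = dist.DiscreteHMM(
    initial_logits=initial_probs.log(),
    transition_logits=transition_matrix.log(),
    observation_dist=dist.Categorical(emission_probs),
    duration=50
)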

hmm.sample()
tensor([3, 2, 5, 5, 3, 5, 5, 5, 5, 5, 5, 3, 5, 0, 5, 5, 1, 5, 5, 1, 0, 1, 3, 1,
        5, 5, 1, 1, 5, 1, 5, 0, 5, 5, 5, 5, 5, 3, 5, 5, 2, 1, 3, 5, 5, 5, 5, 0,
        5, 2])
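Note that hmm.sample() returns only the emissions, not the fair/loaded states behind them, so there is no ground truth to compare an inference result against. A minimal NumPy sketch of ancestral sampling that returns both (sample_hmm is a hypothetical helper, not a Pyro API). It follows the convention of the explicit model_2 later in this thread: z_init is drawn from initial_probs and the first observed state z_0 comes from one transition.

```python
import numpy as np

initial_probs = np.array([0.5, 0.5])
transition_matrix = np.array([[0.95, 0.05],
                              [0.10, 0.90]])
emission_probs = np.array([[1/6] * 6,
                           [1/10] * 5 + [5/10]])

def sample_hmm(T, rng):
    """Ancestral sampling; returns (states, emissions), each of length T."""
    states = np.empty(T, dtype=int)
    emissions = np.empty(T, dtype=int)
    z = rng.choice(2, p=initial_probs)                     # z_init
    for t in range(T):
        z = rng.choice(2, p=transition_matrix[z])          # z_t ~ A[z_{t-1}]
        states[t] = z
        emissions[t] = rng.choice(6, p=emission_probs[z])  # x_t ~ B[z_t]
    return states, emissions

rng = np.random.default_rng(0)
states, emissions = sample_hmm(5000, rng)
```

With the true states available, any posterior estimate can be scored against them directly.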

Each emission was generated by some latent state (fair or loaded die). Initially, I figured I could use the .filter() method of the DiscreteHMM distribution to obtain the posterior of the latent state given a sequence of observations.

emissions = hmm.sample()
post_states = hmm.filter(emissions)

Here, emissions is the sequence of observations generated by the HMM, and post_states is a Categorical distribution. However, .filter() marginalizes out the time dimension, so post_states is only the posterior over the final state given the sequence of observations; it does not give the most likely state that produced each observation. I would like the most probable state for each observation in the sequence, using an inference technique such as forward filtering or Viterbi (MAP).
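For comparison, the forward recursion behind filtering produces p(z_t | x_0..x_t) at every step, and a posterior over the final state corresponds to keeping only the last row. A NumPy sketch, where forward_filter is a hypothetical helper and start_probs = initial_probs @ transition_matrix assumes the z_init convention of model_2 later in the thread. These are filtered, not smoothed, probabilities: each row conditions only on observations up to time t.

```python
import numpy as np

transition_matrix = np.array([[0.95, 0.05],
                              [0.10, 0.90]])
emission_probs = np.array([[1/6] * 6,
                           [1/10] * 5 + [5/10]])
# Distribution of the first observed state, assuming z_init ~ [0.5, 0.5] then one transition
start_probs = np.array([0.5, 0.5]) @ transition_matrix

def forward_filter(emissions):
    """Return filtered[t] = p(z_t | x_0..x_t) and log p(x_0..x_{T-1})."""
    filtered = np.empty((len(emissions), len(start_probs)))
    log_evidence = 0.0
    predicted = start_probs                          # p(z_t | x_0..x_{t-1})
    for t, x in enumerate(emissions):
        unnorm = predicted * emission_probs[:, x]    # multiply in p(x_t | z_t)
        norm = unnorm.sum()                          # p(x_t | x_0..x_{t-1})
        filtered[t] = unnorm / norm
        log_evidence += np.log(norm)
        predicted = filtered[t] @ transition_matrix  # predict one step ahead
    return filtered, log_evidence

filtered, log_evidence = forward_filter([3, 2, 5, 5, 3, 5, 5, 5, 5, 5])
print(filtered[-1])  # posterior over the final state only
```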

This led me to the infer_discrete() function and its temperature parameter. With temperature=1 it samples via forward-filter backward-sample, and with temperature=0 it performs Viterbi-like MAP inference. I also read through the Inference with Discrete Latent Variables tutorial. Using infer_discrete standalone, I expected to be able to obtain MAP estimates of the latent state given a sequence of observations.
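To make the two temperature settings concrete, here is a NumPy sketch of the algorithms they are described as: forward-filter backward-sample (one joint posterior draw of the whole state sequence) and Viterbi (the single most probable sequence). This is an independent illustration, not Pyro's implementation; ffbs and viterbi are hypothetical helper names, and start_probs again assumes the z_init convention of model_2 later in the thread.

```python
import numpy as np

A = np.array([[0.95, 0.05], [0.10, 0.90]])        # transition_matrix
B = np.array([[1/6] * 6, [1/10] * 5 + [5/10]])    # emission_probs
start_probs = np.array([0.5, 0.5]) @ A            # assumes z_init then one transition

def ffbs(emissions, rng):
    """temperature=1 analogue: one joint sample z_0..z_{T-1} ~ p(z | x)."""
    T, K = len(emissions), len(start_probs)
    filtered = np.empty((T, K))
    predicted = start_probs
    for t, x in enumerate(emissions):             # forward filtering
        unnorm = predicted * B[:, x]
        filtered[t] = unnorm / unnorm.sum()
        predicted = filtered[t] @ A
    z = np.empty(T, dtype=int)
    z[-1] = rng.choice(K, p=filtered[-1])
    for t in range(T - 2, -1, -1):                # backward sampling
        w = filtered[t] * A[:, z[t + 1]]          # proportional to p(z_t | z_{t+1}, x_0..x_t)
        z[t] = rng.choice(K, p=w / w.sum())
    return z

def viterbi(emissions):
    """temperature=0 analogue: argmax_z p(z | x), computed in log space."""
    T, K = len(emissions), len(start_probs)
    log_A, log_B = np.log(A), np.log(B)
    score = np.log(start_probs) + log_B[:, emissions[0]]
    backptr = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        cand = score[:, None] + log_A             # cand[i, j]: best path ending in i, then i -> j
        backptr[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_B[:, emissions[t]]
    z = np.empty(T, dtype=int)
    z[-1] = score.argmax()
    for t in range(T - 1, 0, -1):                 # follow the back-pointers
        z[t - 1] = backptr[t, z[t]]
    return z

rolls = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + [5] * 10
print(viterbi(rolls))  # ten 0s (fair) followed by ten 1s (loaded)
```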

import pyro
from pyro.infer import infer_discrete

def model():

    emissions = pyro.sample("emissions", dist.DiscreteHMM(
        initial_logits=initial_probs.log(),
        transition_logits=transition_matrix.log(),
        observation_dist=dist.Categorical(emission_probs),
        duration=50
    ), infer={"enumerate": "sequential"})

    return emissions

serve = infer_discrete(model, first_available_dim=-1, temperature=0)
serve()
tensor([5, 5, 5, 3, 1, 2, 4, 1, 2, 0, 5, 5, 4, 5, 5, 5, 5, 2, 5, 5, 0, 5, 5, 5,
        1, 1, 5, 5, 1, 1, 5, 1, 5, 5, 2, 5, 5, 5, 5, 4, 5, 5, 4, 5, 3, 5, 1, 5,
        4, 3])

However, this only returns emissions sampled from the HMM, not the latent states. Is it possible to use infer_discrete standalone to obtain MAP estimates of the latent states? Or will I need to define a model with priors and perform training to achieve this? I would like the posterior probability of the latent state that gave rise to each observation. The tutorial linked above has a time series example, but it does not use the DiscreteHMM distribution.
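The per-time posterior asked for here is the smoothing distribution p(z_t | x_0..x_{T-1}), which, unlike a filtered posterior, conditions on the whole sequence. For a small discrete HMM it can be computed exactly with the forward-backward algorithm. A NumPy sketch, where smooth is a hypothetical helper and start_probs assumes the z_init convention of model_2 later in the thread; column 1 of the result is P(loaded die) at each step, which is what averaging many temperature=1 samples estimates.

```python
import numpy as np

A = np.array([[0.95, 0.05], [0.10, 0.90]])        # transition_matrix
B = np.array([[1/6] * 6, [1/10] * 5 + [5/10]])    # emission_probs
start_probs = np.array([0.5, 0.5]) @ A            # assumes z_init then one transition

def smooth(emissions):
    """Return post[t, k] = p(z_t = k | x_0..x_{T-1}) via forward-backward."""
    T, K = len(emissions), len(start_probs)
    alpha = np.empty((T, K))          # alpha[t] = p(z_t | x_0..x_t)
    predicted = start_probs
    for t, x in enumerate(emissions):
        unnorm = predicted * B[:, x]
        alpha[t] = unnorm / unnorm.sum()
        predicted = alpha[t] @ A
    beta = np.ones((T, K))            # beta[t] proportional to p(x_{t+1}..x_{T-1} | z_t)
    for t in range(T - 2, -1, -1):
        b = A @ (B[:, emissions[t + 1]] * beta[t + 1])
        beta[t] = b / b.sum()         # rescale each step to avoid underflow
    post = alpha * beta
    return post / post.sum(axis=1, keepdims=True)

p_loaded = smooth([3, 2, 5, 5, 3, 5, 5, 5, 5, 5])[:, 1]
print(p_loaded.round(3))
```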

Thanks for your time and help. Cheers!

Hi @GStechschulte, great question! Pyro’s DiscreteHMM distribution currently does not implement a .sample_posterior() method, so there is no direct way to draw samples or compute the Viterbi solution conditioned on observations. However, you can use Pyro’s infer_discrete with an explicit HMM model, as in the hmm example, something like this:

import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex

def model_2(xs):
    z = pyro.sample("z_init", dist.Categorical(initial_probs))
    for t in pyro.markov(range(len(xs))):
        z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(transition_matrix)[..., z, :]))
        pyro.sample(f"x_{t}", dist.Categorical(Vindex(emission_probs)[..., z, :]), obs=xs[t])

Then you can use infer_discrete to sample or optimize a sequence of z_* values and stack them together. Training these explicit models is slower than training a DiscreteHMM, but since this particular model is equivalent, you can train the DiscreteHMM and then use an explicit model for sampling.


Thanks for the quick response on this. I am a bit busy at the moment, but will incorporate your suggestions and loop back to this post with my results. Thanks!

Defining an explicit HMM model decorated with @config_enumerate (because we want discrete pyro.sample statements to be enumerated exhaustively rather than randomly sampled) gives the following model:

import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex

@config_enumerate
def model_2(xs):
    # sample an initial latent state
    z = pyro.sample('z_init', dist.Categorical(initial_probs))
    states = []
    for t, y in pyro.markov(enumerate(xs)):
        # no need for infer={"enumerate": ...}: the @config_enumerate
        # decorator marks every discrete sample site for enumeration
        z = pyro.sample(
            f'z_{t}',
            dist.Categorical(Vindex(transition_matrix)[..., z, :]),
            )
        states.append(z)
        pyro.sample(
            f"obs_{t}",
            dist.Categorical(Vindex(emission_probs)[..., z, :]),
            obs=y
            )

    return states

By returning states, I can sample a sequence of z values (one per element of xs) using the infer_discrete function. If I read the Inference with Discrete Latent Variables tutorial correctly, infer_discrete can be used standalone to obtain samples and MAP estimates of the latent states, without embedding it inside SVI (via TraceEnum_ELBO), HMC, or NUTS.

However, if I wanted to perform training, I would define a guide with no exposed parameters (I don’t care about learning any global parameters such as the transition and emission probabilities) and hide the discrete latent variables z_*. I would then use TraceEnum_ELBO:

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta

# emission: the observed sequence of die rolls (a LongTensor of length 100)
hmm_guide = AutoDelta(
    pyro.poutine.block(
        model_2,
        hide_fn=lambda msg: msg["name"].startswith("z_"))
    )

optimizer = pyro.optim.Adam({'lr': 0.03})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model_2, hmm_guide, optimizer, loss=elbo)
#elbo.loss(model_2, hmm_guide, emission)

pyro.clear_param_store()

losses = []
for step in range(501):
    loss = svi.step(emission)
    losses.append(loss)

Note that the rest of this example does not use the training block above. Instead, model_2 is defined and infer_discrete is used to generate latent-state samples conditioned on the observations; I call it in a for loop to generate 100 samples of the z_* sequence.

post_samples = torch.Tensor(
    [infer_discrete(model_2, first_available_dim=-1, temperature=1)(emission) for _ in range(100)]
    )
# sampled latent states
post_samples[0]
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

Side note: if I had performed training, I could then feed these samples into the Predictive utility to obtain posterior predictive samples of each latent state z_* and emission obs_* in the sequence.

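from pyro.infer import Predictive

# Predictive expects a dict keyed by sample-site name (z_0, z_1, ...),
# each entry having a leading num_samples dimension
site_samples = {f"z_{t}": post_samples[:, t].long() for t in range(post_samples.shape[1])}
post_pred = Predictive(model=model_2, posterior_samples=site_samples)(emission)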

I can then use post_samples to estimate the probability of a loaded die at each time t as the fraction of samples in state 1.

# post_samples has shape (num_samples, T); averaging the 0/1 states over samples
# gives the Monte Carlo estimate of P(loaded die) at each time step
probs_loaded = post_samples.mean(dim=0)

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 2))
plt.plot(
    torch.arange(len(probs_loaded)),  # one point per time step (100 here)
    probs_loaded
);
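Each entry of probs_loaded is an average of N = 100 independent 0/1 draws (independent because every infer_discrete call at temperature=1 samples afresh). Writing p_t for the true posterior probability of a loaded die at time t, its Monte Carlo standard error is bounded by:

$$\mathrm{SE}_t = \sqrt{\frac{p_t(1-p_t)}{N}} \le \frac{1}{2\sqrt{N}} = \frac{1}{2\sqrt{100}} = 0.05$$

So wiggles of a few percentage points in the P(loaded die) curve are within sampling noise, and the bound shrinks as one over the square root of the number of samples.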

Plotting one of these posterior samples (drawn by infer_discrete with temperature=1, i.e. forward-filter backward-sample) over the emissions suggests it does a reasonable job of identifying the latent states:

import numpy as np

one_hot = torch.nn.functional.one_hot(emission, 6).T

plt.figure(figsize=(10, 2))
plt.imshow(one_hot, aspect="auto", interpolation="none", cmap="Greys")
plt.imshow(
    post_samples[0][None, :], extent=(0, len(emission), 6-.5, -.5),
    interpolation="none", aspect="auto", cmap="Greys", alpha=0.25
    )
plt.xlabel("time")
plt.ylabel("emission")
plt.yticks(np.arange(6), np.arange(6) + 1)
plt.title("sampled sequence (white=fair, gray=loaded)");

Looking at the top plot of P(loaded die) and the bottom plot showing which latent state was sampled, the results do seem reasonable.