Prediction with Hidden Markov Model (HMM) examples

Hi,

Thank you for the great examples of how to perform inference in a hidden Markov model with Pyro (in pyro/examples/hmm.py).

I wonder what the best practice is for the following task. Run the HMM on some input data for t steps, then:

  1. Infer the probabilities of the hidden state and the observation at step t+1.
  2. Infer the probabilities of some of the observations, conditioned on a given subset of the observations.

Thank you and kind regards,
Jan

Hi @jan,
I believe you can infer the marginal probabilities using TraceEnum_ELBO.compute_marginals(). To condition on previous observations, you’ll probably need either to pass data containing some Nones into model() or to use poutine.condition to set those observations.
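For intuition about task 2, here is a plain-Python sketch of the underlying computation, not of Pyro’s API: forward-backward on a discrete HMM in which an unobserved step (None, mirroring the idea of passing Nones into the model) contributes no likelihood term, followed by the predictive distribution of one missing observation. Because y_t is conditionally independent of the other observations given z_t, p(y_t | observed) = sum_i p(z_t = i | observed) p(y_t | z_t = i). The parameterization (init, trans, emit as nested lists of known probabilities) and the function names are hypothetical.

```python
# Hypothetical parameterization (known, not learned here):
#   init[i]     = p(z_0 = i)
#   trans[i][j] = p(z_{t+1} = j | z_t = i)
#   emit[i][k]  = p(y_t = k | z_t = i)


def smooth(init, trans, emit, ys):
    """Forward-backward where ys[t] is None if y_t is unobserved.
    Returns gamma[t][i] = p(z_t = i | observed entries of ys)."""
    n, T = len(init), len(ys)

    def lik(t, i):
        # An unobserved step is marginalized out, so its likelihood factor is 1.
        return 1.0 if ys[t] is None else emit[i][ys[t]]

    # Forward pass: alpha[t][i] is proportional to p(z_t = i, observed y_0..y_t).
    alpha, prev = [], None
    for t in range(T):
        if t == 0:
            a = [init[i] * lik(0, i) for i in range(n)]
        else:
            a = [sum(prev[i] * trans[i][j] for i in range(n)) * lik(t, j) for j in range(n)]
        s = sum(a)
        prev = [x / s for x in a]  # rescale each step to avoid underflow
        alpha.append(prev)

    # Backward pass: beta[t][i] is proportional to p(observed y_{t+1}.. | z_t = i).
    beta = [[1.0] * n for _ in range(T)]
    for t in range(T - 2, -1, -1):
        b = [sum(trans[i][j] * lik(t + 1, j) * beta[t + 1][j] for j in range(n)) for i in range(n)]
        s = sum(b)
        beta[t] = [x / s for x in b]

    gamma = []
    for t in range(T):
        g = [alpha[t][i] * beta[t][i] for i in range(n)]
        s = sum(g)
        gamma.append([x / s for x in g])
    return gamma


def predict_missing(init, trans, emit, ys, t):
    """p(y_t = k | observed entries of ys) for a position t where ys[t] is None."""
    g = smooth(init, trans, emit, ys)[t]
    return [sum(g[i] * emit[i][k] for i in range(len(init))) for k in range(len(emit[0]))]


trans = [[0.9, 0.1], [0.2, 0.8]]
emit = [[0.8, 0.2], [0.3, 0.7]]
print(predict_missing([0.5, 0.5], trans, emit, [0, None, 1], 1))
```

Scaling alpha and beta at every step changes only constant factors, which the final normalization of gamma removes.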

Note that Pyro’s dynamic programming is intended to be used on entire time series at once, typically inside an SVI or HMC training loop where we’re training over multiple time series. Pyro’s implementation does not save state while sequentially running through data, so it won’t be a great solution for, e.g., predicting the next state in a control problem (there you’d instead use Pyro to train an amortized guide, e.g. a neural net, and use that neural net for prediction). We’re working on making filtering-style prediction easier in Pyro, but it’s a long way from making it into a Pyro release.
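The contrast between a whole-series computation and a filter that saves state can be sketched in plain Python. The hypothetical OnlineHMMFilter below keeps the filtered belief p(z_t | y_1..y_t) between calls, so each new observation costs one predict-and-update step, and the same belief yields the step t+1 predictions from task 1. It assumes known parameters (init, trans, emit as nested lists) and is not a Pyro class.

```python
class OnlineHMMFilter:
    """Hypothetical stateful filter for a discrete HMM with known parameters:
    init[i] = p(z_1 = i), trans[i][j] = p(z_{t+1} = j | z_t = i), emit[i][k] = p(y_t = k | z_t = i)."""

    def __init__(self, init, trans, emit):
        self.trans, self.emit = trans, emit
        self.n = len(init)
        self.prior = list(init)  # p(z_1) before any data has been seen
        self.belief = None       # p(z_t | y_1..y_t) once at least one observation arrived

    def _propagate(self, belief):
        # p(z_{t+1} = j | y_1..y_t) = sum_i p(z_t = i | y_1..y_t) * trans[i][j]
        return [sum(belief[i] * self.trans[i][j] for i in range(self.n)) for j in range(self.n)]

    def predict_state(self):
        """Distribution of the next hidden state given everything seen so far."""
        return list(self.prior) if self.belief is None else self._propagate(self.belief)

    def predict_obs(self):
        """Distribution of the next observation given everything seen so far."""
        z = self.predict_state()
        return [sum(z[j] * self.emit[j][k] for j in range(self.n)) for k in range(len(self.emit[0]))]

    def update(self, y):
        """Condition on a newly observed symbol y and store the filtered belief."""
        z = self.predict_state()
        w = [z[j] * self.emit[j][y] for j in range(self.n)]
        total = sum(w)
        self.belief = [x / total for x in w]
        return self.belief


f = OnlineHMMFilter([0.5, 0.5], [[0.9, 0.1], [0.2, 0.8]], [[0.8, 0.2], [0.3, 0.7]])
for y in [0, 0, 1]:  # run the filter for t = 3 steps
    f.update(y)
print(f.predict_state(), f.predict_obs())  # distributions of z_4 and y_4
```

After calling update(y) for each of the first t observations, predict_state() and predict_obs() return the distributions of z_{t+1} and y_{t+1} without revisiting earlier data.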


I had to do something similar. My task was to run inference on a fresh test set of 21 sequences, and I tried two approaches.

A low-level implementation of inference from a lifted model:

...
        for _ in range(num_of_samples):
            # Draw latent values from the guide, then replay the model on them
            guide_trace = pyro.poutine.trace(self.guide).get_trace(self.sequences)
            lifted_model = pyro.poutine.replay(self.model, guide_trace)
            pred_trace = pyro.poutine.trace(lifted_model).get_trace(self.sequences)
            # Keep only the sites whose names start with the requested prefix, e.g. "z_"
            state_seq_keys = [k for k in pred_trace.nodes.keys() if k.startswith(affix_node_to_predict)]
            preds.append([pred_trace.nodes[k]['value'].numpy() for k in state_seq_keys])

and using the new Predictive class

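        # Option 1: read the site names from a trace of the model
        # (poutine.trace(...) only wraps the model; .get_trace(...) actually runs it)
        trace_nodes = pyro.poutine.trace(self.model).get_trace(self.sequences).nodes.keys()
        state_seq_keys = [k for k in trace_nodes if k.startswith(affix_node_to_predict)]
        # Option 2 (overrides option 1): build the names directly, one per padded time step
        state_seq_keys = [affix_node_to_predict + str(k) for k in range(self.sequences['input'].shape[1])]
        pred_infer = Predictive(self.model, guide=self.guide, num_samples=num_of_samples,
                                return_sites=state_seq_keys)
        return pred_infer(self.sequences)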

Let’s assume that num_of_samples=1. Both approaches gave similar results: they sample all observation sites of my sequential model (y_0, ..., y_986), which does not match the lengths of the sequences in the test set. For example, the prediction for a new observation sequence of length 50 included all sites up to y_986, which wastes both runtime and inference effort. I guess I need to trim to the length of each test sequence, but the inference is still simply slow…
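One way to request fewer sites is to build return_sites only up to the longest test sequence and then drop, per sequence, the steps beyond its own length. Note that return_sites only restricts what Predictive returns; the model still runs its full time loop, whose bound is set in the model code. In this plain-Python sketch, samples is a hypothetical stand-in for the returned dict (site name mapped to one value per sequence), and the prefix "y_" and helper names are assumptions.

```python
def site_names(prefix, max_len):
    """Site names prefix0 .. prefix{max_len-1}, e.g. y_0, y_1, ... (hypothetical helper)."""
    return [f"{prefix}{t}" for t in range(max_len)]


def trim_to_lengths(samples, prefix, lengths):
    """samples[name][b] is the predicted value of site `name` for sequence b.
    Returns one list per sequence holding only its first lengths[b] predictions."""
    return [[samples[f"{prefix}{t}"][b] for t in range(length)]
            for b, length in enumerate(lengths)]


lengths = [1, 3]                                 # lengths of two test sequences
names = site_names("y_", max(lengths))           # candidate return_sites
samples = {"y_0": [1, 2], "y_1": [3, 4], "y_2": [5, 6]}
print(names, trim_to_lengths(samples, "y_", lengths))
```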

Furthermore, training with batch_size=15 yields an inferred vector/dict with dimensions [986, 15, num_samples=1].
It seems that 6 sequences have vanished because of the predefined batch_size (the test set has 21). Can I recover them somehow?
Is there any other best practice for predicting a sequence and testing my model?
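In the model further down this thread, the plate is built with subsample_size=self.batch_size, so each model call sees only batch_size randomly chosen sequences; 15 of 21 leaves 6 out, which matches the shape above. A common workaround is to loop over deterministic index batches that together cover every sequence and hand each batch to the plate through pyro.plate’s subsample argument, e.g. pyro.plate("sequence_list", size=self.num_seqs, subsample=torch.tensor(idx)). Passing idx into model() would need an extra argument that is an assumption, not part of the posted code. The helper below (hypothetical name) only builds the batches.

```python
def index_batches(num_seqs, batch_size):
    """Split sequence indices 0..num_seqs-1 into consecutive batches.
    Every index appears exactly once; the last batch may be smaller."""
    return [list(range(start, min(start + batch_size, num_seqs)))
            for start in range(0, num_seqs, batch_size)]


for idx in index_batches(21, 15):
    print(len(idx), idx[0], idx[-1])  # 15 sequences (0..14), then 6 sequences (15..20)
```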

Hi @noam,

Can you first see whether our new, faster HMM inference algorithm works for you? Pyro 0.4.0 introduced a DiscreteHMM distribution, and 0.4.1 introduced a DiscreteHMM.filter() method for forecasting. The basic idea is to define a model with even just a single DiscreteHMM sample site, fit it on data using SVI, use .filter() to predict the final latent state, and finally sample future observations from that latent state.
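The last step of that recipe, sampling future observations once the distribution of the final latent state is known, looks roughly like this in plain Python. It is a sketch under assumptions: final_belief stands in for whatever .filter() gives you (here simply a list of probabilities), trans and emit are known parameter tables, and the function names are hypothetical.

```python
import random


def sample_categorical(rng, probs):
    # Inverse-CDF sampling from a discrete distribution given as a list of probabilities
    u = rng.random()
    acc = 0.0
    for k, p in enumerate(probs):
        acc += p
        if u < acc:
            return k
    return len(probs) - 1  # guard against floating-point round-off


def forecast(final_belief, trans, emit, horizon, num_samples, seed=0):
    """Sample num_samples future observation paths of length horizon,
    starting from the distribution of the final latent state z_T."""
    rng = random.Random(seed)
    paths = []
    for _ in range(num_samples):
        z = sample_categorical(rng, final_belief)          # z_T ~ p(z_T | y_1..y_T)
        path = []
        for _ in range(horizon):
            z = sample_categorical(rng, trans[z])          # z_{t+1} ~ p(. | z_t)
            path.append(sample_categorical(rng, emit[z]))  # y_{t+1} ~ p(. | z_{t+1})
        paths.append(path)
    return paths


print(forecast([0.7, 0.3], [[0.9, 0.1], [0.2, 0.8]], [[0.8, 0.2], [0.3, 0.7]],
               horizon=5, num_samples=3))
```

Averaging many such paths gives a Monte Carlo estimate of the forecast distribution at each future step.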

If this doesn’t make sense, could you paste more details of your problem to help me understand what you’re trying to do?

Cheers,
Fritz


Thanks @fritzo,
I will try DiscreteHMM, since it should speed up my computations; however, I am not sure it is applicable to my model. I use the same model I posted here. My current code is:

import pyro
import pyro.distributions as dist
import torch
from pyro import poutine
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex
from pyro.util import ignore_jit_warnings

# model() is a method of the model class; self.state_emitter, self.ar_emitter,
# self.num_states, self.num_seqs, self.batch_size, self.lengths, self.args are defined elsewhere
@config_enumerate
def model(self, sequences, include_prior=True):
    with ignore_jit_warnings():
        if not isinstance(sequences, dict):
            raise TypeError("expected a dict with 'input' and 'output' tensors")
        input_seq = sequences["input"]
        output_seq = sequences["output"].squeeze()
        # initial hidden state and observation fed to the emitters at t = 0
        z = torch.zeros(1)
        y = torch.zeros(1)

    pyro.module("state_emitter", self.state_emitter)
    pyro.module("ar_emitter", self.ar_emitter)
    # subsample_size: each call sees only batch_size randomly chosen sequences
    seq_plate = pyro.plate("sequence_list", size=self.num_seqs, subsample_size=self.batch_size)

    with poutine.mask(mask=include_prior):
        # transition matrix in the hidden state [ num_states X num_states ]
        transition_dist = dist.Dirichlet(
            0.5 * torch.eye(self.num_states) + 0.5 / (self.num_states - 1)).to_event(1)
        probs_lat = pyro.sample("probs_lat", transition_dist)

    with seq_plate as batch:
        lengths = self.lengths[batch]
        input_batch = input_seq[batch, :]

        for t in pyro.markov(range(0, self.lengths.max() if self.args.jit else lengths.max())):
            # mask out time steps beyond each sequence's length
            t_mask = (t < lengths).type(torch.BoolTensor)
            with poutine.mask(mask=t_mask):
                # px.shape = [batch_size X num_states]
                px = self.state_emitter(input_batch[:, t, :].type(torch.FloatTensor), z)
                z_dist = dist.Categorical(Vindex(probs_lat)[..., px.argmax(dim=-1), :])
                z = pyro.sample(f"z_{t}", z_dist).type(torch.FloatTensor)
                assert t_mask.shape == z_dist.batch_shape[seq_plate.dim:]

                py = self.ar_emitter(y, z)  # py.shape = [batch_size X num_emission]
                obs_dist = dist.Categorical(py)
                # y.shape = [batch_size X 1]
                y = pyro.sample(f"y_{t}", obs_dist, obs=output_seq[batch, t]).type(torch.FloatTensor)

Now I want to compare different models using several metrics (NLL, accuracy, etc.), and my idea is to predict the observations on a test set. Does this change your suggestion?
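A sketch of the two metrics, assuming that for each test sequence and step you already have a predictive distribution over observation symbols (for example one-step-ahead predictions from a filter, or frequencies from Predictive samples). The function name and input layout are hypothetical; steps at or beyond each sequence’s length are ignored, like the t_mask in the model. For a plain HMM, if the probabilities are exact one-step-ahead predictions p(y_t | y_1..y_{t-1}), the summed NLL of a sequence equals its negative log marginal likelihood by the chain rule.

```python
import math


def nll_and_accuracy(pred_probs, targets, lengths):
    """pred_probs[b][t][k]: predicted p(y_t = k) for sequence b; targets[b][t]: true symbol.
    Only steps t < lengths[b] are scored, so padding is ignored.
    Returns (mean negative log-likelihood per step, accuracy of the argmax prediction)."""
    total_nll, correct, count = 0.0, 0, 0
    for probs_b, y_b, length in zip(pred_probs, targets, lengths):
        for t in range(length):
            p = probs_b[t]
            total_nll -= math.log(p[y_b[t]])
            pred = max(range(len(p)), key=lambda k: p[k])  # ties go to the lowest index
            correct += int(pred == y_b[t])
            count += 1
    return total_nll / count, correct / count


probs = [[[0.5, 0.5], [0.25, 0.75]], [[0.9, 0.1], [0.6, 0.4]]]
print(nll_and_accuracy(probs, [[0, 1], [1, -1]], [2, 1]))  # -1 is padding, never read
```

Comparing models on the same test sequences with the same masking keeps both numbers directly comparable.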