HMM Tutorial and infer_discrete Forecasting Question

Hi,

I’m trying to generate predictions of the data from an enumerated HMM model (based on the HMM tutorial), but the results don’t appear to be very accurate. I’m not sure whether something is wrong with the model itself or with how I’m generating the forecasts. The model is defined below.

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.util import ignore_jit_warnings

def hmm_model_1(sequences, lengths, hidden_dim=16, predictive=False):
    # JIT boilerplate code.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = sequences.shape

    # - Global Priors (transition matrix for latents z and emission matrix for data x).
    probs_z = pyro.sample('probs_z', dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
    probs_x = pyro.sample('probs_x', dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2))

    # - Generative process.
    tones_plate = pyro.plate('tones', data_dim, dim=-1)

    with pyro.plate('sequences', num_sequences, dim=-2):
        z = 0  # initial state is 0
        # Iterate through each time period in the sequence.
        for t in pyro.markov(range(max_length)):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                z = pyro.sample('z_{}'.format(t), dist.Categorical(probs_z[z]), infer={'enumerate': 'parallel'})
                # Iterate through each dimension of the observation for one time period.
                with tones_plate:
                    obs = pyro.sample('x_{}'.format(t), dist.Bernoulli(probs_x[z.squeeze(-1)]),
                                      obs=sequences[:, t, :] if not predictive else None)
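
For reference, the data is the polyphonic music data from the tutorial. Below is a minimal sketch of how I load it (assuming the tutorial’s default JSB chorales dataset and its loader; the dictionary keys come from my version of Pyro and may differ in yours):

# Data-loading sketch, assuming the tutorial's default polyphonic music data (JSB chorales).
from pyro.contrib.examples import polyphonic_data_loader as poly

data = poly.load_data(poly.JSB_CHORALES)
sequences = data['train']['sequences']           # (num_sequences, max_length, data_dim), entries are 0/1
lengths = data['train']['sequence_lengths']      # (num_sequences,), valid length of each sequence
test_sequences = data['test']['sequences']
test_lengths = data['test']['sequence_lengths']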

The original tutorial trains for 50 steps by default, but I trained for 150 steps and reached an ELBO loss of ~570. I could tell it hadn’t fully converged yet, but since that was three times as much training as in the tutorial, I figured it was enough to start testing the model. I also noticed the paper referenced in the tutorial seems to report a log likelihood of about 8. Based on how my loss was plateauing, I don’t think I’d ever get close to that, although I’m not sure that’s actually a problem, since they may be normalizing their loss differently (e.g., by time periods instead of by sequences).

My training loop is below and mimics the HMM tutorial: an AutoDelta guide gives MAP estimates of the transition and emission matrices, the JIT is just there to speed up training, and enumeration marginalizes out the local latents z:

import pyro.optim
from pyro.infer import SVI, JitTraceEnum_ELBO, infer_discrete
from pyro.infer.autoguide import AutoDelta

guide = AutoDelta(poutine.block(hmm_model_1, expose_fn=lambda msg: msg['name'].startswith('probs')))
optimizer = pyro.optim.Adam({'lr': 0.05})
elbo = JitTraceEnum_ELBO(max_plate_nesting=2)
svi = SVI(hmm_model_1, guide, optimizer, elbo)

# Training.
num_steps = 150
for step in range(num_steps):
    loss = svi.step(sequences, lengths)
    print(f"Step {step}, Loss: {loss / lengths.numel()}")

Then I use infer_discrete to get posterior predictive samples of the data for a few time periods, so I can compare them with the observed data and see whether the trained model comes close. However, the predictions and the observed data don’t seem to align very well.

So I’m not sure whether I’m doing something wrong when generating predictions and there’s a better way to do it (perhaps my use of infer_discrete or the trace/replay below is incorrect), or whether something is wrong with the model itself and the predictions would improve with a lower ELBO loss. The code for generating my predictions is below.

seq = 10  # select a sequence
test_seq = test_sequences[seq].unsqueeze(0)  # the model expects 3D input, so add a batch dimension of size 1
test_len = test_lengths[seq]  # e.g., with a length of 33, the last time index containing data is 32

x25_samples, x26_samples, x27_samples = [], [], []

with torch.no_grad():
    # Get MAP estimates of the global latents (i.e., the transition and emission matrices).
    guide_trace = poutine.trace(guide).get_trace(sequences, lengths)
    # Condition the model on those MAP estimates.
    trained_model = poutine.replay(hmm_model_1, trace=guide_trace)
    # Create a model that concretely samples the discrete latents that were marginalized out
    # during training (temperature=0 gives MAP assignments of the hidden states).
    serving_model = infer_discrete(trained_model, temperature=0, first_available_dim=-3)
    # Draw observation samples for certain time periods.
    for i in range(100):
        trace = poutine.trace(serving_model).get_trace(test_seq, test_len, predictive=True)
        x25_samples.append(trace.nodes['x_25']['value'])
        x26_samples.append(trace.nodes['x_26']['value'])
        x27_samples.append(trace.nodes['x_27']['value'])

# Convert list of tensors (i.e., samples) to one tensor (containing all samples).
x25_tensor = torch.stack(x25_samples, dim=0).squeeze()
x26_tensor = torch.stack(x26_samples, dim=0).squeeze()
x27_tensor = torch.stack(x27_samples, dim=0).squeeze()
# Get the fraction of 1s across samples for each dimension. Note the observed data is a multi-dimensional vector of 0s and 1s (corresponding to piano chords in the polyphonic data).
x25_pred = x25_tensor.mean(dim=0)
x26_pred = x26_tensor.mean(dim=0)
x27_pred = x27_tensor.mean(dim=0)
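
For reference, below is a minimal sketch of the kind of comparison I’m doing between those predicted frequencies and the held-out chords (the time indices 25–27 are just the ones I picked above, and the mean-absolute-difference metric is only illustrative):

# Compare predicted per-dimension frequencies with the observed 0/1 chord vectors.
x25_obs = test_seq[0, 25, :]
x26_obs = test_seq[0, 26, :]
x27_obs = test_seq[0, 27, :]

for name, pred, obs in [('x_25', x25_pred, x25_obs),
                        ('x_26', x26_pred, x26_obs),
                        ('x_27', x27_pred, x27_obs)]:
    print(name, 'mean abs diff:', (pred - obs).abs().mean().item())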

Any insights or help with making predictions (filtered in-sample predictions and/or out-of-sample forecasts) with this or a general HMM would be greatly appreciated. Thanks!