HMM with Gaussian observations

Hey,
I followed the HMM examples from link and tried to build a simple HMM with multi-dimensional Gaussian observations (adapted from model_0 and model_1 in the tutorial).
I used the following prior assumptions:
Transition table ← Dirichlet prior
Covariances ← LKJ prior
Mean ← Gaussian prior
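
In other words, the generative process I have in mind is roughly (my notation, matching the code below):

z_t ~ Categorical(A[z_{t-1}])
x_t ~ MultivariateNormal(mu[z_t], scale_tril = diag(sqrt(theta[z_t])) @ Omega[z_t])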

The problem is that I get completely wrong results when comparing the true state sequence with the Viterbi-decoded sequence.
Are there any mistakes in the model definition?

Here is the full model code:

import torch

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import config_enumerate, infer_discrete


@config_enumerate
def hmm_model(sequence, args, include_prior=True):
    assert not torch._C._get_tracing_state()
    data_dim = sequence.shape[1]
    seq_length = sequence.shape[0]

    # With mask=False the log-probs of the sample sites below are set to 0,
    # i.e. the priors are effectively switched off
    with poutine.mask(mask=include_prior):
        # Sample from transition prior
        transition_probs = pyro.sample(
            "transition_probs",
            dist.Dirichlet(0.1 * torch.ones(args.num_states, args.num_states)).to_event(1))

        # Sample from covariance prior
        theta_probs = pyro.sample(
            "theta_probs",
            dist.HalfCauchy(torch.ones(data_dim).expand([args.num_states, data_dim])).to_event(2))
        eta = torch.ones(1)
        L_omega_probs = pyro.sample(
            "L_omega_probs",
            dist.LKJCorrCholesky(data_dim, eta).expand([args.num_states]).to_event(1))
        
        # Combine the per-state scales with the correlation Cholesky factors
        # to get a lower-triangular scale_tril for every hidden state
        cov_probs = torch.bmm(theta_probs.sqrt().diag_embed(), L_omega_probs)

        # Sample from mean prior
        mean_probs = pyro.sample(
            "mean_probs",
            dist.MultivariateNormal(
                torch.zeros(args.num_states, data_dim),
                torch.eye(data_dim).expand([args.num_states, data_dim, data_dim])).to_event(1))


    # Marginalize states and condition on observations
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    z = 0
    for t in pyro.markov(range(seq_length)):
        z = pyro.sample("z_{}".format(t), dist.Categorical(transition_probs[z]),
                        infer={"enumerate": "parallel"})
        with tones_plate:
            pyro.sample("x_{}".format(t), dist.MultivariateNormal(mean_probs[z], scale_tril=cov_probs[z]),
                        obs=sequence[t])

how are you doing inference? for example are you doing MAP w.r.t. transition_probs?

Yes, I’m doing MAP inference w.r.t. transition_probs.

Here is the Viterbi code I use; it is also adapted from one of the examples.

@infer_discrete(first_available_dim=-1, temperature=0)
@config_enumerate
def viterbi_decoder(data, transition_probs, means, covs):
    states = [0]
    for t in pyro.markov(range(len(data))):
        states.append(pyro.sample("states_{}".format(t),
                                  dist.Categorical(transition_probs[states[-1]])))
        pyro.sample("obs_{}".format(t),
                    dist.MultivariateNormal(means[states[-1]], scale_tril=covs[states[-1]]),
                    obs=data[t])
    return states  # returns maximum likelihood states
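
In case it helps, I call it roughly like this after training (a sketch; trans_map, means_map and scale_tril_map are placeholders for whatever point estimates inference produces):

decoded = viterbi_decoder(sequence, trans_map, means_map, scale_tril_map)
decoded = torch.stack(decoded[1:])  # drop the fixed initial state 0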

I fixed the problem of getting completely wrong results; it was caused by choosing the wrong ELBO. With TraceEnum_ELBO the results are sometimes correct.
If not, the states look permuted relative to the true labels.

For example:
True sequence: 000022222221100
Inferred sequence: 111122222220011 or 0000111111112200

Maybe this is related to the hyperparameters of the priors?
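
For completeness, the inference setup now looks roughly like this (a sketch: I am showing an AutoDelta guide and illustrative Adam/step settings, not an exact copy of my script):

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

# MAP point estimates for the continuous latents; the discrete z_t sites are
# enumerated out by TraceEnum_ELBO, so they must not appear in the guide
guide = AutoDelta(poutine.block(hmm_model, expose=["transition_probs", "theta_probs",
                                                   "L_omega_probs", "mean_probs"]))

svi = SVI(hmm_model, guide, Adam({"lr": 0.01}), TraceEnum_ELBO(max_plate_nesting=1))
for step in range(1000):
    svi.step(sequence, args)

map_estimates = guide(sequence, args)  # dict of MAP values, e.g. map_estimates["mean_probs"]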

why do you expect to infer the “true” states exactly? isn’t “within a permutation” all you might expect?

I expected the indices of the mean and covariance tensors to correspond to the “true” state numbers.
Does Pyro shuffle the order of the indices, or is there something else I’m not getting?

doesn’t your model include a permutation symmetry among the hidden state labels?
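
e.g. you could match the decoded labels to the true ones with a best-permutation search over a confusion matrix, roughly like this (just a sketch; the helper name is illustrative and it uses scipy):

from scipy.optimize import linear_sum_assignment
import torch

def relabel(true_states, decoded_states, num_states):
    # confusion[i, j] counts how often true state i was decoded as state j
    confusion = torch.zeros(num_states, num_states)
    for t, z in zip(true_states, decoded_states):
        confusion[int(t), int(z)] += 1
    # pick the label permutation that maximizes the total agreement
    rows, cols = linear_sum_assignment(-confusion.numpy())
    mapping = {int(c): int(r) for r, c in zip(rows, cols)}
    return [mapping[int(z)] for z in decoded_states]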

Ah, now I get it. I forgot to handle that case.
Thank you!