Hey,
I followed the HMM examples from link and tried
to build a simple HMM with multi dimensional Gaussian observations. (adapted model_0, model_1 from the tutorial).
I used the following prior assumptions:
Transition table ← Dirichlet prior
Covariances ← LKJ prior
Mean ← Gaussian prior
The problem is that I get completely wrong results when comparing the true sequence with a Viterbi decoded sequence.
Are there any mistakes in the model definition?
Here is the full model code:
@config_enumerate
def hmm_model(sequence, args, include_prior=True):
assert not torch._C._get_tracing_state()
data_dim = sequence.shape[1]
seq_length = sequence.shape[0]
# If mask=False log_probs of samples are set to 0 inside mask context
# ==> uses no sampled prior -> cancelled out due to setting to 0
with poutine.mask(mask=include_prior):
# Sample from transition prior
transition_probs = pyro.sample("transition_probs", dist.Dirichlet(torch.ones(args.num_states, args.num_states)
.fill_(0.1)).to_event(1))
# Sample from covariance prior
theta_probs = pyro.sample("theta_probs", dist.HalfCauchy(torch.ones(data_dim).expand([args.num_states, data_dim]))
.to_event(2))
eta = torch.ones(1)
L_omega_probs = pyro.sample("L_omega_probs", dist.LKJCorrCholesky(data_dim, eta).expand([args.num_states]).to_event(1))
cov_probs = torch.bmm(theta_probs.sqrt().diag_embed(), L_omega_probs)
# Sample from mean prior
mean_probs = pyro.sample('mean_probs', dist.MultivariateNormal(torch.zeros(args.num_states, data_dim),
torch.eye(data_dim).expand([args.num_states, data_dim])).to_event(1))
# Marginalize states and condition on observations
tones_plate = pyro.plate("tones", data_dim, dim=-1)
z = 0
for t in pyro.markov(range(seq_length)):
z = pyro.sample("z_{}".format(t), dist.Categorical(transition_probs[z]),
infer={"enumerate": "parallel"})
with tones_plate:
pyro.sample("x_{}".format(t), dist.MultivariateNormal(mean_probs[z], scale_tril=cov_probs[z]),
obs=sequence[t])