Hi all,
I’m trying to fit a custom hidden Markov model to some sequential data, similar to the models on the HMM examples page. In addition, I would like to:

First, use partial labels. For some of my observations I partially know the set of states that generated them. For example, sequences[0][9] was drawn from either state 0 or state 2. Is there a way I can incorporate this information into the model?
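
One possible approach, sketched under assumptions rather than taken from Pyro's API: treat each partial label as an extra observation whose likelihood is 1 for the allowed states and 0 for the others. In log space that is a mask of 0 / -inf added to the per-state emission log-likelihoods before running the forward algorithm. The function `masked_hmm_log_prob` and its `allowed` argument are hypothetical names. The sketch assumes the initial distribution is over the state one step *before* the first observation, which appears to match `DiscreteHMM`'s convention (worth checking against its docs).

```python
import torch


def masked_hmm_log_prob(init_logits, trans_logits, obs_logliks, allowed):
    """Forward-algorithm log-likelihood of a discrete HMM with partial state labels.

    init_logits:  (S,) logits of the hidden state one step before the first observation.
    trans_logits: (S, S) logits; rows are source states, columns destination states.
    obs_logliks:  (..., T, S) log p(x_t | z_t = s) for each step t and state s.
    allowed:      (..., T, S) bool; allowed[..., t, s] is False if state s is ruled
                  out at step t. Unlabelled steps are all-True rows.
    Returns a tensor of shape (...), one log-likelihood per sequence.
    """
    # Normalize so that unnormalized logits (or -inf entries) are accepted.
    init_logits = init_logits - init_logits.logsumexp(-1, keepdim=True)
    trans_logits = trans_logits - trans_logits.logsumexp(-1, keepdim=True)
    # A partial label is extra evidence with likelihood 1 for allowed states
    # and 0 otherwise, i.e. a log-mask of 0 / -inf added to the emissions.
    log_mask = torch.zeros_like(obs_logliks).masked_fill(~allowed, float("-inf"))
    # alpha[..., s] = log p(evidence up to step t, z_t = s)
    alpha = init_logits
    for t in range(obs_logliks.shape[-2]):
        # Marginalize the previous state, then add this step's evidence.
        alpha = (alpha.unsqueeze(-1) + trans_logits).logsumexp(-2)
        alpha = alpha + obs_logliks[..., t, :] + log_mask[..., t, :]
    return alpha.logsumexp(-1)
```

In a model like the one below, one could compute `obs_logliks = dist.Bernoulli(emission_probs).to_event(1).log_prob(y.unsqueeze(-2))` (shape `(batch, T, hidden_dim)`) and replace the `DiscreteHMM` sample site with `pyro.factor("y", masked_hmm_log_prob(init_logits, trans_logits, obs_logliks, allowed[batch, : lengths.max()]))`, where `allowed` is a `(num_sequences, max_length, hidden_dim)` boolean tensor; for sequences[0][9], `allowed[0, 9]` would be True only at states 0 and 2. Putting the same -inf mask into the transition logits passed to `DiscreteHMM` is not equivalent: `DiscreteHMM` renormalizes each transition row, which turns the label into a change of dynamics instead of evidence.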

Second, impose constraints on the transition matrix by forbidding certain state transitions, so that the matrix has a specific structure. In the code below I tried to do this by manually setting the log-probabilities of the forbidden entries to -inf (or a very negative value). This seems to do the trick, but is there a cleaner way to do it?
```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine


def hmm(sequences, lengths, args, batch_size=None, include_prior=True):
    num_sequences, max_length, data_dim = map(int, sequences.shape)
    with poutine.mask(mask=include_prior):
        prior_probs = pyro.sample("prior_probs", dist.Dirichlet(0.9 * torch.ones(args.hidden_dim) + 0.1))
        transition_probs = pyro.sample("transition_probs", dist.Dirichlet(9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        emission_probs = pyro.sample("emission_probs", dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))
    # Plate dims must be negative (counted from the right).
    with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch:
        lengths = lengths[batch]
        y = sequences[batch, : lengths.max()]
        init_logits = prior_probs.log()
        # Forbid transitions 1->0, 2->0 and 3->0 by giving them log-probability -inf.
        # (A large positive value such as 1e2 would instead make them almost certain.)
        forbidden = torch.zeros(args.hidden_dim, args.hidden_dim, dtype=torch.bool)
        forbidden[1:4, 0] = True
        trans_logits = transition_probs.log().masked_fill(forbidden, float("-inf"))
        obs_dist = dist.Bernoulli(emission_probs).to_event(1)
        pyro.sample("y", dist.DiscreteHMM(init_logits, trans_logits, obs_dist), obs=y)
```
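
A slightly tidier variant of the same masking idea, as a sketch: describe the allowed structure once as a boolean matrix and apply it with `masked_fill`, instead of writing individual entries in place. The helper `constrained_transition_logits` and the `allowed` matrix are illustrative names, not Pyro API. A log-probability of -inf forbids a transition; its output could be passed as the transition logits to `DiscreteHMM`.

```python
import torch


def constrained_transition_logits(transition_probs, allowed):
    """Normalized transition log-probs with forbidden entries set to -inf.

    transition_probs: (..., S, S) row-stochastic matrix (rows = source state).
    allowed: (S, S) bool structure matrix; allowed[i, j] is False if i -> j is
        forbidden. Every row must keep at least one allowed entry.
    """
    # log(0) for forbidden transitions; the mask broadcasts over batch dims.
    logits = transition_probs.log().masked_fill(~allowed, float("-inf"))
    # Renormalize each row over its allowed destinations only.
    return logits - logits.logsumexp(-1, keepdim=True)
```

For example, `allowed = torch.ones(S, S, dtype=torch.bool)` followed by `allowed[1:4, 0] = False` encodes the structure from the code above. One property makes this masking reasonable with a Dirichlet prior: renormalizing a subset of the components of a Dirichlet vector gives a Dirichlet over that subset with the corresponding concentrations. So each masked row still has a Dirichlet prior over its allowed entries, and the sampled values of the forbidden entries are simply discarded.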