I am working on a factorial HMM (fHMM) and want to place a constraint so the diagonals of one transition matrix always has a larger diagonal mean than the other (i.e. there are more inter-state transitions in Z than X). I am currently using poutine.mask()
in a hacky way to accomplish this task but I am uncertain about how this works under the hood and whether there’s a more appropriate way to set constraints between two pyro parameters. Can I get some feedback on this choice?
My code is loosely based off model_3()
in the HMM tutorials
def model_3(sequences, args, batch_size=None, include_prior=True):
with ignore_jit_warnings():
num_sequences, max_length, data_dim = map(int, sequences.shape)
# establish priors
with poutine.mask(mask=include_prior):
# transition matrix priors for both hidden states
probs_x = pyro.sample("probs_x",
dist.Dirichlet(0.95 * torch.eye(args.hidden_dim) + 0.05)
.to_event(1))
probs_z = pyro.sample("probs_z",
dist.Dirichlet(0.1*torch.eye(2) + 0.9)
.to_event(1))
# scale terms for state X
probs_var_x = pyro.sample("probs_var_x",
dist.Beta(torch.ones(args.hidden_dim, dtype=torch.float) * 2.0,
torch.ones(args.hidden_dim, dtype=torch.float) * 3.0)
.to_event(1))
cn_plate = pyro.plate("cn", data_dim, dim=-1)
# CONSTRAINT ON TRANSITION DIAGONALS
with poutine.mask(mask=(torch.mean(torch.diagonal(probs_z))<torch.mean(torch.diagonal(probs_x)))):
with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
x = 0
z = 0
for t in pyro.markov(range(max_length)):
x = pyro.sample("x_{}".format(t), dist.Categorical(Vindex(probs_x)[x]),
infer={"enumerate": "parallel"})
z = pyro.sample("z_{}".format(t), dist.Categorical(Vindex(probs_z)[z]),
infer={"enumerate": "parallel"})
with cn_plate:
pyro.sample("y_{}".format(t), dist.Normal(x*(1+z), Vindex(probs_var_x)[x]),
obs=sequences[batch, t])