# Constraints between parameters in fHMM

I am working on a factorial HMM (fHMM) and want to place a constraint so that one transition matrix always has a larger diagonal mean than the other (i.e. there are more inter-state transitions in Z than in X). I am currently using `poutine.mask()` in a hacky way to accomplish this, but I am uncertain how it works under the hood and whether there is a more appropriate way to set constraints between two Pyro parameters. Can I get some feedback on this choice?

My code is loosely based on `model_3()` in the HMM tutorials:

```python
def model_3(sequences, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)

    # establish priors
    # transition matrix priors for both hidden states
    probs_x = pyro.sample("probs_x",
                          dist.Dirichlet(0.95 * torch.eye(args.hidden_dim) + 0.05)
                              .to_event(1))
    probs_z = pyro.sample("probs_z",
                          dist.Dirichlet(0.1 * torch.eye(2) + 0.9)
                              .to_event(1))

    # scale terms for state X
    probs_var_x = pyro.sample("probs_var_x",
                              dist.Beta(torch.ones(args.hidden_dim, dtype=torch.float) * 2.0,
                                        torch.ones(args.hidden_dim, dtype=torch.float) * 3.0)
                                  .to_event(1))

    cn_plate = pyro.plate("cn", data_dim, dim=-1)
    # CONSTRAINT ON TRANSITION DIAGONALS
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        x = 0
        z = 0
        for t in pyro.markov(range(max_length)):
            x = pyro.sample("x_{}".format(t), dist.Categorical(Vindex(probs_x)[x]),
                            infer={"enumerate": "parallel"})
            z = pyro.sample("z_{}".format(t), dist.Categorical(Vindex(probs_z)[z]),
                            infer={"enumerate": "parallel"})
            with cn_plate:
                pyro.sample("y_{}".format(t), dist.Normal(x * (1 + z), Vindex(probs_var_x)[x]),
                            obs=sequences[batch, t])
```

Hi @aweiner,
I suspect the `poutine.mask` hack will either make inference very slow or make the variational posterior ill-defined. But I bet you could get this working by either softening the constraint or transforming to an unconstrained space.

To soften the constraint, you could replace the `poutine.mask` statement with a factor statement like:

```python
temperature = 10.0  # for example
violation = probs_z.diagonal().mean() - probs_x.diagonal().mean()
pyro.factor("constraint", -torch.nn.functional.softplus(temperature * violation))
```
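As a sanity check (a standalone sketch, with the example `temperature` value and illustrative violation magnitudes I chose here), the softplus factor behaves like a smooth hinge penalty: when the constraint is satisfied (`violation < 0`) the log-density adjustment is near zero, and when it is violated (`violation > 0`) the penalty grows roughly linearly in the violation.

```python
import torch
import torch.nn.functional as F

temperature = 10.0  # same example value as in the factor statement

# Penalty when the constraint is comfortably satisfied (violation = -0.5):
satisfied = -F.softplus(temperature * torch.tensor(-0.5))
# Penalty when the constraint is violated by the same margin (violation = +0.5):
violated = -F.softplus(temperature * torch.tensor(0.5))

print(satisfied.item())  # near zero: almost no penalty
print(violated.item())   # roughly -temperature * violation: a strong penalty
```

Raising `temperature` sharpens the hinge, pushing the soft constraint closer to a hard one at the cost of a stiffer optimization landscape.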

To transform to an unconstrained space, you could follow the pattern of the torch.distributions.transforms and torch.distributions.constraints libraries, which are extended by Pyro. In your case you could force `probs_x.diagonal().mean()` to be greater than `probs_z.diagonal().mean()` by adding some portion of `probs_x` to `probs_z`, say:

```python
# First sample unconstrained latent variables.
probs_x = pyro.sample(
    "probs_x",
    dist.Dirichlet(0.95 * torch.eye(args.hidden_dim) + 0.05).to_event(1),
)
probs_z_unconstrained = pyro.sample(
    "probs_z_unconstrained",
    dist.Dirichlet(0.1 * torch.eye(2) + 0.9).to_event(1),
)

# Then transform probs_z.
violation = probs_z_unconstrained.diagonal().mean() - probs_x.diagonal().mean()
weight = torch.sigmoid(violation)
probs_z = pyro.deterministic(
    "probs_z",
    probs_x * weight + probs_z_unconstrained * (1 - weight),
)
```
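One nice property of this transform (here checked in a standalone sketch with randomly generated stand-ins for the two sampled matrices): since `weight = sigmoid(violation)` always lies in (0, 1), the blend is a convex combination of two row-stochastic matrices, so each row of the transformed `probs_z` still sums to 1 and remains a valid transition matrix.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the Dirichlet samples: two random 2x2
# row-stochastic (transition) matrices.
probs_x = torch.softmax(torch.randn(2, 2), dim=-1)
probs_z_unconstrained = torch.softmax(torch.randn(2, 2), dim=-1)

violation = probs_z_unconstrained.diagonal().mean() - probs_x.diagonal().mean()
weight = torch.sigmoid(violation)  # always in (0, 1)
probs_z = probs_x * weight + probs_z_unconstrained * (1 - weight)

# A convex combination of row-stochastic matrices is row-stochastic.
print(probs_z.sum(-1))  # each row sums to 1
```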

(I think that should work.)


Using a factor statement worked perfectly. Thank you!