Constraints between parameters in fHMM

I am working on a factorial HMM (fHMM) and want to constrain the model so that the diagonal of one transition matrix always has a larger mean than the other's (i.e. there are more inter-state transitions in Z than in X). I am currently using poutine.mask() in a hacky way to accomplish this, but I am uncertain about how it works under the hood and whether there's a more appropriate way to set constraints between two Pyro parameters. Can I get some feedback on this choice?

My code is loosely based on model_3() in the HMM tutorials:

import torch

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.ops.indexing import Vindex
from pyro.util import ignore_jit_warnings


def model_3(sequences, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
    # establish priors
    with poutine.mask(mask=include_prior):
        # transition matrix priors for both hidden states
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(0.95 * torch.eye(args.hidden_dim) + 0.05)
                                  .to_event(1))
        probs_z = pyro.sample("probs_z",
                              dist.Dirichlet(0.1 * torch.eye(2) + 0.9)
                                  .to_event(1))

        # scale terms for state X
        probs_var_x = pyro.sample("probs_var_x",
                                  dist.Beta(torch.ones(args.hidden_dim, dtype=torch.float) * 2.0,
                                            torch.ones(args.hidden_dim, dtype=torch.float) * 3.0)
                                      .to_event(1))

    cn_plate = pyro.plate("cn", data_dim, dim=-1)
    # CONSTRAINT ON TRANSITION DIAGONALS
    with poutine.mask(mask=(torch.mean(torch.diagonal(probs_z))<torch.mean(torch.diagonal(probs_x)))):
        with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
            x = 0
            z = 0
            for t in pyro.markov(range(max_length)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(Vindex(probs_x)[x]),
                                infer={"enumerate": "parallel"})
                z = pyro.sample("z_{}".format(t), dist.Categorical(Vindex(probs_z)[z]),
                                infer={"enumerate": "parallel"})
                with cn_plate:
                    pyro.sample("y_{}".format(t), dist.Normal(x*(1+z), Vindex(probs_var_x)[x]),
                                obs=sequences[batch, t])

Hi @aweiner,
I suspect the poutine.mask hack will either make inference very slow or make the variational posterior ill-defined. But I bet you could get this working by either softening the constraint or by transforming to an unconstrained space.

To soften the constraint, you could replace the poutine.mask statement with a factor statement like

import torch.nn.functional as F

temperature = 10.0  # for example
# violation > 0 means the constraint is broken
violation = probs_z.diagonal().mean() - probs_x.diagonal().mean()
pyro.factor("constraint", -F.softplus(temperature * violation))
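For intuition, here's a small numeric sketch (plain PyTorch, with made-up violation values) of how that penalty behaves: it is nearly zero when the constraint is satisfied and grows roughly linearly with the size of the violation, so the factor statement nudges the posterior away from violating draws without hard-masking them.

```python
import torch
import torch.nn.functional as F

temperature = 10.0

# Constraint satisfied: probs_z's diagonal mean is well below probs_x's.
satisfied = torch.tensor(-0.3)  # violation = mean(diag(z)) - mean(diag(x))
# Constraint violated: probs_z's diagonal mean is above probs_x's.
violated = torch.tensor(0.3)

penalty_ok = F.softplus(temperature * satisfied)   # ~0.049, tiny log-density cost
penalty_bad = F.softplus(temperature * violated)   # ~3.049, large cost

print(penalty_ok.item(), penalty_bad.item())
```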

To transform to an unconstrained space, you could follow the pattern of the torch.distributions.transforms and torch.distributions.constraints libraries, which Pyro extends. In your case you could force probs_x.diagonal().mean() to be greater than probs_z.diagonal().mean() by adding some portion of probs_x to probs_z, say

# First sample unconstrained latent variables.
probs_x = pyro.sample(
    "probs_x",
    dist.Dirichlet(0.95 * torch.eye(args.hidden_dim) + 0.05).to_event(1),
)
probs_z_unconstrained = pyro.sample(
    "probs_z_unconstrained",
    dist.Dirichlet(0.1 * torch.eye(2) + 0.9).to_event(1),
)

# Then transform probs_z.
violation = probs_z_unconstrained.diagonal().mean() - probs_x.diagonal().mean()
weight = torch.sigmoid(violation)
probs_z = pyro.deterministic(
    "probs_z",
    probs_x * weight + probs_z_unconstrained * (1 - weight),
)

(I think that should work :thinking:)
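One property of that transform is easy to check numerically: since probs_z is a convex combination of two row-stochastic matrices, its rows still sum to 1, and blending toward probs_x shrinks the diagonal-mean gap. A minimal sketch, using hypothetical 2×2 matrices where the constraint is initially violated:

```python
import torch

# Hypothetical draws where the constraint is violated:
# probs_z_unconstrained's diagonal mean (0.85) exceeds probs_x's (0.65).
probs_x = torch.tensor([[0.6, 0.4],
                        [0.3, 0.7]])
probs_z_unconstrained = torch.tensor([[0.9, 0.1],
                                      [0.2, 0.8]])

violation = probs_z_unconstrained.diagonal().mean() - probs_x.diagonal().mean()
weight = torch.sigmoid(violation)
probs_z = probs_x * weight + probs_z_unconstrained * (1 - weight)

# Rows of the blended matrix still sum to 1 ...
print(probs_z.sum(-1))
# ... and its diagonal mean has moved from 0.85 toward probs_x's 0.65.
print(probs_z.diagonal().mean().item(), probs_x.diagonal().mean().item())
```

Note that the blend softens rather than hard-enforces the ordering, which is in keeping with the spirit of the factor-statement approach above.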


Using a factor statement worked perfectly. Thank you!