LinearHMM Reparameterization Reference


I was looking through Pyro's hidden Markov model distributions and came across the HiddenMarkovModel classes. I am particularly interested in the LinearHMM (as opposed to the GaussianHMM) because it applies an additional transformation to the observations. My goal is to use transformations such as normalizing flows and/or (variational) autoencoders on the observations, with the latent space described by a GaussianHMM (e.g. see this paper and this paper, respectively). In fact, the pseudo-code already describes exactly this model:

z = initial_distribution.sample()
xs = []
for t in range(num_events):
    z = z @ transition_matrix + transition_dist.sample()  # latent linear dynamics + process noise
    x_transform = z @ observation_matrix + obs_base_dist.sample()  # linear emission + observation noise
    x = obs_transform(x_transform)  # e.g. NF, VAE, SurVAE Flow
    xs.append(x)
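For concreteness, here is a minimal runnable NumPy version of the generative process above. It is a sketch under assumptions not in the original: all noise is Gaussian, the initial state is standard normal, and elementwise `np.exp` is a hypothetical stand-in for `obs_transform` (a real normalizing flow would be a learned invertible map). The helper name and parameter values are made up for illustration.

```python
import numpy as np


def sample_linear_hmm(num_events, transition_matrix, observation_matrix,
                      trans_scale=0.1, obs_scale=0.1, obs_transform=np.exp, seed=0):
    """Hypothetical helper: simulate the LinearHMM generative process.

    Assumes Gaussian noise everywhere; obs_transform defaults to np.exp
    as a stand-in for an invertible flow.
    """
    rng = np.random.default_rng(seed)
    hidden_dim = transition_matrix.shape[0]
    obs_dim = observation_matrix.shape[1]
    z = rng.normal(size=hidden_dim)  # z ~ initial distribution (standard normal here)
    xs = []
    for _ in range(num_events):
        # Latent linear dynamics plus Gaussian process noise.
        z = z @ transition_matrix + trans_scale * rng.normal(size=hidden_dim)
        # Linear emission plus Gaussian noise, in the untransformed space.
        y = z @ observation_matrix + obs_scale * rng.normal(size=obs_dim)
        # Invertible transform of the observation (stand-in for a flow).
        xs.append(obs_transform(y))
    return np.stack(xs)


xs = sample_linear_hmm(
    num_events=5,
    transition_matrix=0.9 * np.eye(2),
    observation_matrix=np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]]),
)
print(xs.shape)  # (5, 3)
```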

It is mentioned that the log_prob method isn't implemented, and I see here that it is recommended to use LinearHMMReparam with a TransformedDistribution. It seems super clever, but is there a reference paper covering the mathematical details? I tried to gain insight from the code but wasn't really able to understand how it works.

The other option, of course, is to do things from scratch using the excellent Hidden Markov Model and Deep Markov Model tutorials as a reference. However, I wanted to check whether it would be an easy “plug-and-play” with the already implemented HMMs, especially since a lot of work went into optimising the GaussianHMM.
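If one did go the from-scratch route, the likelihood that has to be computed is small enough to sketch. The NumPy code below is an illustration under stated assumptions, not Pyro's implementation: it assumes Gaussian initial, transition and observation noise, evaluates the linear-Gaussian part with a plain Kalman filter, and uses elementwise `np.exp` as a hypothetical stand-in for an invertible `obs_transform`. The change of variables then adds the log-Jacobian of the inverse, which for `exp` is `-sum(log x)`. Both function names are made up for this example.

```python
import numpy as np


def gaussian_hmm_log_prob(ys, init_mean, init_cov, trans_matrix, trans_cov, obs_matrix, obs_cov):
    """Kalman-filter log-likelihood of ys (shape (T, obs_dim)).

    Row-vector convention as in the pseudo-code:
        z_t = z_{t-1} @ trans_matrix + N(0, trans_cov)
        y_t = z_t @ obs_matrix + N(0, obs_cov),  with z_0 ~ N(init_mean, init_cov).
    """
    mean, cov = init_mean, init_cov
    log_prob = 0.0
    for y in ys:
        # Predict the next hidden state.
        mean = mean @ trans_matrix
        cov = trans_matrix.T @ cov @ trans_matrix + trans_cov
        # Predictive distribution of this observation.
        y_mean = mean @ obs_matrix
        y_cov = obs_matrix.T @ cov @ obs_matrix + obs_cov
        resid = y - y_mean
        _, logdet = np.linalg.slogdet(y_cov)
        log_prob += -0.5 * (resid @ np.linalg.solve(y_cov, resid) + logdet + len(y) * np.log(2 * np.pi))
        # Condition the hidden state on y (Kalman update).
        gain = cov @ obs_matrix @ np.linalg.inv(y_cov)  # shape (hidden_dim, obs_dim)
        mean = mean + resid @ gain.T
        cov = cov - gain @ y_cov @ gain.T
    return log_prob


def exp_linear_hmm_log_prob(xs, *hmm_params):
    """log p(x_{1:T}) when x_t = exp(y_t) and y_{1:T} follows the Gaussian HMM above.

    Change of variables: y_t = log(x_t) and log|det dy_t/dx_t| = -sum(log x_t).
    """
    ys = np.log(xs)
    return gaussian_hmm_log_prob(ys, *hmm_params) - ys.sum()


params = (np.zeros(2), np.eye(2), 0.9 * np.eye(2), 0.1 * np.eye(2),
          np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]]), 0.1 * np.eye(3))
xs = np.exp(np.random.default_rng(0).normal(size=(5, 3)))
print(exp_linear_hmm_log_prob(xs, *params))
```

Swapping `np.exp` for a learned flow only changes the inverse and the log-determinant term; the Kalman part stays the same, which is why an optimised GaussianHMM can be reused.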

Hi @jejjohnson, your example looks like an excellent topic for a tutorial, in case you'd like to contribute 🙂

We have not yet published Pyro’s HMM reparameterization framework. The idea of LinearHMMReparam is to use effect handlers to transform a model involving a complex set of distributions into one involving simpler distributions. I believe in your case you could implement your normalizing flow as a Transform or TransformModule object (e.g. AffineAutoregressive; see the Distributions page of the Pyro documentation) and then use TransformReparam. Something like this may work:

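with poutine.reparam(config={"x": LinearHMMReparam(obs=TransformReparam())}):
    obs_dist = dist.TransformedDistribution(
        dist.Normal(torch.zeros(dim), 1).to_event(1),  # or another observation noise
        dist.transforms.AffineAutoregressive(...),  # or another normalizing flow
    )
    pyro.sample("x", dist.LinearHMM(..., obs_dist))  # "..." = initial_dist, transition_matrix, transition_dist, observation_matrix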
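One way to see why this works (a sketch of the reasoning, not a published derivation): as I understand the LinearHMMReparam documentation, when the observation distribution is a TransformedDistribution the transforms are moved outside, so the result is a TransformedDistribution of a GaussianHMM. With Gaussian noise, the linear part is a GaussianHMM over y_t (the `x_transform` in the pseudo-code), and an invertible per-step `obs_transform` f only contributes a Jacobian term by the change-of-variables formula. Here A, B, Q and R are notation assumed for this sketch: `transition_matrix`, `observation_matrix`, and the transition and observation noise covariances.

```latex
\begin{aligned}
z_t &= z_{t-1} A + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, Q) \\
y_t &= z_t B + \eta_t, \qquad \eta_t \sim \mathcal{N}(0, R) \\
x_t &= f(y_t) \\
\log p(x_{1:T}) &= \log p_{\mathrm{GaussianHMM}}\big(f^{-1}(x_1), \ldots, f^{-1}(x_T)\big)
  + \sum_{t=1}^{T} \log \left| \det \frac{\partial f^{-1}(x_t)}{\partial x_t} \right|
\end{aligned}
```

This requires f to be invertible with a tractable Jacobian, which holds for normalizing flows such as AffineAutoregressive but not for a generic (non-invertible) VAE decoder.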

Feel free to put up a little tutorial PR for discussion!