Feedback on Extending the HMM Example to a Gaussian Likelihood

Hi,

I have written the following code, extending the HMM example to a Gaussian likelihood, with one Gaussian per HMM state.

import torch

import pyro
import pyro.distributions as dist
from pyro import poutine


def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape

    # Our prior on transition probabilities will be:
    # stay in the same state with 50% probability; uniformly jump to another
    # state with 50% probability.
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(0.5 * torch.eye(args.hidden_dim) + 0.5)
                                  .to_event(1))
        # K Gaussians, one per hidden state, with a normal-inverse-gamma-style
        # prior: precision tau ~ Gamma(2, 2) (equivalently sigma^2 ~ InvGamma(2, 2)),
        # then mu | tau ~ Normal(0, sqrt(v0 / tau)).
        v0 = 1.  # scales the prior variance of the means relative to 1 / tau
        with pyro.plate("components", args.hidden_dim):
            sig_inv = pyro.sample("scales_y",
                                  dist.Gamma(2.0, 2.0).expand([data_dim]).to_event(1))
            mu = pyro.sample("locs_y",
                             dist.Normal(0., torch.sqrt(v0 / sig_inv)).to_event(1))
    sig = sig_inv.rsqrt()  # observation scale sigma = 1 / sqrt(tau), used below

    # We vectorize over sequences in the minibatch along dim=-2, and over the
    # data dimensions ("tones" in the original example) along dim=-1.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate:
                    # Index the per-state means and scales by the current state x.
                    pyro.sample("y_{}".format(t),
                                dist.Normal(mu[x.squeeze(-1)], sig[x.squeeze(-1)]),
                                obs=sequences[batch, t])

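For reference, I train it roughly the way the HMM example does, with TraceEnum_ELBO so the discrete states are enumerated out. This is just a minimal sketch; the AutoDelta guide, the learning rate, and args.num_steps / args.batch_size are placeholder choices on my side:

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

# Point-estimate guide over the continuous latents only; the discrete
# x_t sites are handled by enumeration rather than by the guide.
guide = AutoDelta(poutine.block(model_0,
                                expose=["probs_x", "scales_y", "locs_y"]))
optim = Adam({"lr": 0.05})
elbo = TraceEnum_ELBO(max_plate_nesting=2)  # sequences at dim=-2, tones at dim=-1
svi = SVI(model_0, guide, optim, elbo)
for step in range(args.num_steps):
    loss = svi.step(sequences, lengths, args, batch_size=args.batch_size)
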
It does not throw any errors. Do you think it is a reasonable implementation, or do you see any bugs/issues with it?

I wanted to have a global normal-inverse-gamma prior on the parameters of the Gaussians. Do you think it is reasonable to factor it as in the code above (a Gamma prior on the precision, then a Normal on the mean conditioned on it), or should I implement a separate stochastic function named NormalInverseGamma? A sketch of what I mean by the latter is below.
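
Something like the following is what I had in mind; a minimal sketch, where the helper name and its parameters are hypothetical, and pyro.distributions.InverseGamma should give the same model as the Gamma-on-precision parameterization above:

def normal_inverse_gamma(name, loc0, lam, alpha, beta):
    # Hypothetical helper: sigma^2 ~ InverseGamma(alpha, beta), then
    # mu | sigma^2 ~ Normal(loc0, sqrt(sigma^2 / lam)).
    var = pyro.sample("{}_var".format(name),
                      dist.InverseGamma(alpha, beta).to_event(1))
    loc = pyro.sample("{}_loc".format(name),
                      dist.Normal(loc0, (var / lam).sqrt()).to_event(1))
    return loc, var.sqrt()

# Usage inside the model, replacing the two sample statements above:
#     with pyro.plate("components", args.hidden_dim):
#         mu, sig = normal_inverse_gamma(
#             "y", loc0=torch.zeros(data_dim), lam=1.0,
#             alpha=2.0 * torch.ones(data_dim), beta=2.0 * torch.ones(data_dim))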

Thanks in advance for the feedback.