It's very tricky to write a complex HMM-like model

I’m trying to implement a topic model combined with an HMM. The rough structure of my model is:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.ops.indexing import Vindex

# Components
motivation_plate = pyro.plate(..., dim=-1)
# Sample component-conditioned RVs
with motivation_plate:
    a = pyro.sample("a", ...)
    b = pyro.sample("b", ...)
# Sample state-topic-dependent RV
mu = torch.zeros((states_num, data_dims))
with motivation_plate:
    gamma = pyro.sample("gamma", dist.Normal(...).to_event(2))
# Sample batched transition probs: one (S, S) matrix per user, biased toward self-transitions
eta_s = 0.9 * torch.eye(self._hyper_params["S"]) + 0.1
with pyro.plate("users", self._data_dims["User"], dim=-1):
    p = pyro.sample("p", dist.Dirichlet(eta_s).to_event(1))
state = 0  # initial state
# Sample obs data
# Batch plate (subsampled users)
with pyro.plate("batched_users", self._data_dims["User"], self._args.batch_size, dim=-1) as batch:
    ...
    for t in pyro.markov(range(max_session_length)):
        with poutine.mask(mask=...):
            if isinstance(state, torch.Tensor) and state.dim() == 2:
                state = state.squeeze()
            p_batch_state = p[batch, state]
            state = pyro.sample(f"state_{t}", dist.Categorical(p_batch_state), infer={"enumerate": "parallel"})
            if state.dim() == 1:
                state = state[..., None]
            gamma_t = Vindex(gamma)[..., state, :]
            # Two different ways to compute the same variable in the non-enum and enum runs of the model
            if gamma_t.dim() == 3:
                # non-enum model
                zeta_t = torch.bmm(gamma_t, x.unsqueeze(-1)).squeeze(-1).exp()  # (Batch, M)
            else:
                # enum model
                zeta_t = (gamma_t @ x.T).transpose(-1, -2).exp()  # `x` is the obs variable
            # Sample obs data
            with pyro.plate("nested_plate", ..., dim=-2):
                ...
```
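A common way to avoid the two code paths is to write the emission computation purely with broadcasting, so the enumeration dimensions Pyro adds on the left simply pass through. The sketch below is an illustration under stated assumptions, not the code above: `emission_rates` is a hypothetical helper, and it assumes `gamma` has shape `(M, S, D)`, `x` has shape `(B, D)`, and `state` keeps the user dimension rightmost, either `(B,)` in a normal run or `(S, 1, ..., 1)` in an enumerated run. It uses plain PyTorch advanced indexing instead of `Vindex`.

```python
import torch


def emission_rates(gamma, state, x):
    """Hypothetical helper: exp(<gamma[m, state_b], x_b>) for each user b and component m.

    Assumed shapes (not taken from the original model):
      gamma: (M, S, D)  one (S, D) matrix per component
      state: (B,) when not enumerated, or (S, 1, ..., 1) when enumerated,
             with the user dimension always rightmost
      x:     (B, D)     one feature row per user
    Returns a tensor whose last two dims are (B, M); enumeration dims stay on the left.
    """
    M = gamma.shape[0]
    # arange(M) has shape (M,) and state[..., None] has shape (..., B or 1, 1);
    # advanced indexing broadcasts them to (..., B or 1, M), so gamma_t is (..., B or 1, M, D).
    gamma_t = gamma[torch.arange(M), state[..., None]]
    # x.unsqueeze(-2) is (B, 1, D); multiply, then reduce over the feature dim D.
    return (gamma_t * x.unsqueeze(-2)).sum(-1).exp()
```

Under these assumptions the same line serves both runs, so the `if gamma_t.dim() == 3` branch should no longer be needed; whether the `squeeze`/`[..., None]` reshaping can also go depends on how `state` is laid out in the rest of the model.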

I think that for users who are new to Pyro, tensor dimensions are very hard to get right, especially when the model contains discrete latent variables. SVI requires the model to run both with and without enumeration, and the dimensions must stay consistent in both cases. This is probably one of the reasons Pyro is hard to get started with: it took me two whole days to debug my model :frowning: .

My model is now working fine; I wrote the above to get the developers’ opinion on this kind of issue. By the way, I am the author of PR #3195, and I am very happy that you accepted it. My current research focuses on Bayesian modeling, and I hope to communicate with you more in the future.
