Parallelizing (CT)HMM

Hi, I’m trying to extend the HMM models from the examples to Continuous Time (CTHMM).

I only need this for positive integer time intervals, so I think taking matrix powers of the transition matrix is enough (no need for matrix exponentials, etc.).
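As a quick sanity check of that idea: for a time-homogeneous chain, the transition matrix over an integer gap k is the k-th power of the one-step matrix. The sketch below uses a made-up 2-state matrix; `P` and `k_step_transition` are illustrative names and values, not part of the model in this post.

```python
import torch

# Hypothetical one-step transition matrix (each row sums to 1); values are illustrative.
P = torch.tensor([[0.9, 0.1],
                  [0.2, 0.8]])


def k_step_transition(P, k):
    """Transition matrix over an integer interval k, i.e. P multiplied by itself k times."""
    return P.matrix_power(k)


P3 = k_step_transition(P, 3)
# P3 equals P @ P @ P, and each of its rows still sums to 1.
```

With a gap of 1 this reduces to the ordinary HMM transition, so the discrete-time examples are the special case where every interval equals 1.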

So far I’ve managed to implement model_0:

def model_0(sequences, intervals, lengths, args, batch_size=None, include_prior=True):
    ...
    observations_plate = pyro.plate("observations", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        for t in pyro.markov(range(length)):
            if t == 0:
                x = pyro.sample("x_{}_0".format(i), dist.Categorical(probs_pi),
                                infer={"enumerate": "parallel"})
            else:
                # Transition over an integer interval: power of the one-step matrix.
                px = probs_x.matrix_power(intervals[i][t - 1].int())
                x = pyro.sample("x_{}_{}".format(i, t), dist.Categorical(px[x]),
                                infer={"enumerate": "parallel"})
            with observations_plate:
                pyro.sample("y_{}_{}".format(i, t), dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])

The problem comes when I try model_1 (i.e. when I add parallelization over sequences):

def model_1(sequences, intervals, lengths, args, batch_size=None, include_prior=True):
    ...
    observations_plate = pyro.plate("observations", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                if t == 0:
                    x = pyro.sample("x_0", dist.Categorical(probs_pi),
                                    infer={"enumerate": "parallel"})
                else:
                    px = torch.stack([
                        probs_x.matrix_power(intervals[batch, t - 1][i].int())
                        for i in range(batch_size)
                    ])
                    x = pyro.sample("x_{}".format(t), dist.Categorical(px[x]),
                                    infer={"enumerate": "parallel"})

                with observations_plate:
                    pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                obs=sequences[batch, t])

This gives the error:

ValueError: Shape mismatch inside plate('sequences') at site y_1 dim -2, 8 vs 4

px has shape (8, 4, 4), but it seems it should be (4, 4), so torch.stack does not appear to be the right way to compute this.

I would appreciate it if somebody could give me a hint on how to compute matrix_power for all the elements of the batch in parallel.

Thank you!

You may have other shape issues in your code, but we have a batch matrix power utility; see the Miscellaneous Ops page of the Pyro documentation.
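To illustrate what a batch matrix power lookup can look like, here is a plain-PyTorch sketch. It assumes the utility returns the powers M^1, ..., M^n stacked along a new leading dimension, which is what the `- 1` indexing used with repeated_matmul in this thread implies. `stacked_powers` is a hypothetical stand-in written for this example, not the Pyro function itself, and the matrix values are made up.

```python
import torch


def stacked_powers(M, n):
    """Hypothetical stand-in: tensor of shape (n, N, N) holding M^1, ..., M^n."""
    powers = [M]
    for _ in range(n - 1):
        powers.append(powers[-1] @ M)
    return torch.stack(powers)


P = torch.tensor([[0.9, 0.1],
                  [0.2, 0.8]])       # illustrative one-step transition matrix
gaps = torch.tensor([1, 3, 2, 3])    # one integer interval per sequence in the batch

all_powers = stacked_powers(P, int(gaps.max()))  # shape (3, 2, 2)
px = all_powers[gaps - 1]                        # shape (4, 2, 2); px[b] is P to the power gaps[b]
```

The lookup replaces the Python loop over the batch: every power up to the largest gap is computed once, and each sequence picks its own matrix by index. The result therefore carries a leading batch dimension.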

Thank you for answering.

I replaced the torch.stack call with:

max_batch_interval = intervals[batch, t - 1].max().int()
px = repeated_matmul(probs_x, max_batch_interval)[intervals[batch, t - 1].long() - 1]

But as you pointed out, there are still shape issues.
px has shape (8, 4, 4) (with a batch size of 8). Any idea what should be done to make it work?
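One possible reading of the mismatch is that `px[x]` indexes the batch dimension of px with state values, because px now has one matrix per sequence. The sketch below only demonstrates the indexing and shapes in plain PyTorch. It assumes the enumerated state `x` has shape (K, 1, 1), uses random matrices in place of the real px, and has not been run inside the full model.

```python
import torch

B, K = 8, 4                            # batch size and number of hidden states, as in the post
px = torch.rand(B, K, K).softmax(-1)   # stand-in for the per-sequence transition matrices
x = torch.arange(K).reshape(K, 1, 1)   # assumed shape of an enumerated state: (K, 1, 1)

# px[x] selects along the batch dimension using state values, which mixes up the axes.
wrong = px[x]                          # shape (K, 1, 1, K, K)

# Index batch and state together: b has shape (B, 1) and x has shape (K, 1, 1),
# so they broadcast to (K, B, 1); the trailing K holds the next-state probabilities.
b = torch.arange(B).unsqueeze(-1)
right = px[b, x]                       # shape (K, B, 1, K)
```

Under these assumptions the transition site would use `dist.Categorical(px[b, x])`, with the batch dimension at -2 to match the "sequences" plate. Vindex from pyro.ops.indexing follows the same broadcasting convention and may be a more readable way to write it.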

Thank you.