Hi, I’m trying to extend the HMM models from the examples to Continuous Time (CTHMM).
I just want to do it for positive integers time intervals so I think I only need to do the matrix power of the transition matrix (without needing to do matrix exponentials etc.).
So far I’ve managed to implement model_0:
def model_0(sequences, intervals, lengths, args, batch_size=None, include_prior=True):
...
observations_plate = pyro.plate("observations", data_dim, dim=-1)
for i in pyro.plate("sequences", len(sequences), batch_size):
length = lengths[i]
sequence = sequences[i, :length]
for t in pyro.markov(range(length)):
if t == 0:
x = pyro.sample("x_{}_0".format(i), dist.Categorical(probs_pi),
infer={"enumerate": "parallel"})
else:
px = probs_x.matrix_power(intervals[i][t - 1].int())
x = pyro.sample("x_{}_{}".format(i, t), dist.Categorical(px[x]),
infer={"enumerate": "parallel"})
with observations_plate:
pyro.sample("y_{}_{}".format(i, t), dist.Bernoulli(probs_y[x.squeeze(-1)]),
obs=sequence[t])
The problem comes when I try model_1 (aka add parallalization):
def model_1(sequences, intervals, lengths, args, batch_size=None, include_prior=True):
...
observations_plate = pyro.plate("observations", data_dim, dim=-1)
with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
lengths = lengths[batch]
for t in pyro.markov(range(max_length if args.jit else lengths.max())):
with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
if t == 0:
x = pyro.sample("x_0", dist.Categorical(probs_pi), infer={"enumerate": "parallel"})
else:
px = torch.stack([
probs_x.matrix_power(intervals[batch, t - 1][i].int())
for i in range(batch_size)
])
x = pyro.sample("x_{}".format(t), dist.Categorical(px[x]),
infer={"enumerate": "parallel"})
with observations_plate:
pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[x.squeeze(-1)]),
obs=sequences[batch, t])
Which gives the error:
ValueError: Shape mismatch inside plate('sequences') at site y_1 dim -2, 8 vs 4
px has shape 8, 4, 4 but it should be 4, 4. It seems that torch.stack is not the way to compute this.
I would appreciate if somebody could hint how to compute the matrix_power for all the elements in the batch in parallel.
Thank you!