Thanks @fritzo,
I will try the DiscreteHMM
module, since it is supposed to speed up my calculations; however, I am not sure it is applicable to my model. I use the same model I've posted here. My current code is:
@config_enumerate
def model(self, sequences, include_prior=True):
    """Pyro model for an input-driven HMM with autoregressive emissions.

    Args:
        sequences: dict with keys "input" (shape assumed
            [num_seqs, T, input_dim] -- TODO confirm) and "output"
            (observation indices, squeezed to [num_seqs, T]).
        include_prior: mask for the prior over the transition matrix;
            set False to effectively drop the prior term.
    """
    with ignore_jit_warnings():
        if isinstance(sequences, dict):
            input_seq = sequences["input"]
            output_seq = sequences["output"].squeeze()
    # Initial latent state / previous observation placeholders.
    # torch.zeros(1) is the idiomatic float32 equivalent of
    # torch.Tensor([0]).type(torch.FloatTensor).
    z = torch.zeros(1)
    y = torch.zeros(1)
    pyro.module("state_emitter", self.state_emitter)
    pyro.module("ar_emitter", self.ar_emitter)
    seq_plate = pyro.plate("sequence_list", size=self.num_seqs,
                           subsample_size=self.batch_size)
    with poutine.mask(mask=include_prior):
        # Transition matrix over hidden states [num_states x num_states]:
        # Dirichlet rows concentrated on the diagonal (sticky transitions).
        transition_dist = dist.Dirichlet(
            0.5 * torch.eye(self.num_states) + 0.5 / (self.num_states - 1)).to_event(1)
        probs_lat = pyro.sample("probs_lat", transition_dist)
    with seq_plate as batch:
        lengths = self.lengths[batch]
        input_batch = input_seq[batch, :]
        # Under JIT, iterate to the global max length so the trace shape
        # is static; masking below handles the padded steps.
        for t in pyro.markov(range(0, self.lengths.max() if self.args.jit else lengths.max())):
            # Mask out sequences that have already ended at step t.
            t_mask = (t < lengths).type(torch.BoolTensor)
            with poutine.mask(mask=t_mask):
                # px.shape = [batch_size x num_states]
                px = self.state_emitter(input_batch[:, t, :].type(torch.FloatTensor), z)
                z_dist = dist.Categorical(Vindex(probs_lat)[..., px.argmax(dim=-1), :])
                z = pyro.sample(f"z_{t}", z_dist).type(torch.FloatTensor)
                assert t_mask.shape == z_dist.batch_shape[seq_plate.dim:]
                # py.shape = [batch_size x num_emission]  (fixed: comment
                # previously said px)
                py = self.ar_emitter(y, z)
                obs_dist = dist.Categorical(py)
                # y.shape = [batch_size x 1]
                y = pyro.sample(f"y_{t}", obs_dist, obs=output_seq[batch, t]).type(torch.FloatTensor)
Now, I want to compare two different models by several metrics (NLL, accuracy, etc.). I thought about predicting the observations on a held-out test set. Does this change your suggestion?