Hi,
While trying to solve a problem regarding enumeration in my neural DBN model, I ran into a different kind of problem with the batching mechanism. It is a batching problem since it only shows up when subsample_size > 1.
Specifically, I iterate over a time-series input and have a module that predicts the current state from two inputs: the previous state (the enumerated variable z) and the current observation, which I process in batch mode (input_batch[:, t, :]):
def model():
    ...
    with pyro.plate("sequence_list", size=self.num_seqs, subsample_size=self.batch_size, dim=-2) as batch:
        lengths = self.lengths[batch]
        z = torch.tensor(0, dtype=torch.long)
        y = torch.zeros(self.args.batch_size, 1)
        input_batch = input_seq[batch, :]
        for t in pyro.markov(range(0, self.lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                z_current = self.state_emitter(input_batch[:, t, :], z).argmax(dim=1)  # expected shape [batch_size x 1]
                z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(probs_lat)[..., z_current, :]))  # expected shape [batch_size x 1]
                ...
The problem is that I do not need to double-batch over z, but I do want to use z_current, which is per-sample. I need z to hold a single state for each sample in the batch, i.e. z.shape = torch.Size([8, 1]); however, I get z.shape = torch.Size([8, 8]). How would you suggest changing the architecture, or which Pyro handler would you use, in order to disable this double-batching effect?
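For reference, here is a minimal standalone sketch of what I think is happening (toy shapes, not my real model; I am assuming 8 samples in the batch and 4 latent states): once the Categorical's probs already carry the per-sample dimension at dim -1 (as the Vindex lookup produces), sampling inside a plate at dim=-2 broadcasts the two dimensions against each other and gives [8, 8]:

import torch
import pyro
import pyro.distributions as dist

def repro():
    # toy numbers: batch of 8 samples, 4 latent states (both hypothetical)
    probs = torch.rand(8, 4)
    probs = probs / probs.sum(-1, keepdim=True)  # Categorical batch_shape (8,), sitting at dim -1
    with pyro.plate("sequence_list", 8, dim=-2):
        z = pyro.sample("z", dist.Categorical(probs))
    print(z.shape)  # torch.Size([8, 8]): the plate's dim -2 broadcasts against the probs' batch dim -1

repro()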
BTW, I would really appreciate some input regarding my other post