Problem with batching a neural DBN model

Hi,
While trying to solve an enumeration problem in my neural DBN model, I ran into a different problem with the batching mechanism. I'm calling it a batching problem because it fails whenever subsample_size > 1.

Specifically, I iterate over a time-series input and have a module (state_emitter) that predicts the current state from 2 inputs: the previous state (the enumerated variable z) and the current input observation, which I process in batch mode (input_batch[:, t, :]):

def model(self):
    ...

    with pyro.plate("sequence_list", size=self.num_seqs, subsample_size=self.batch_size, dim=-2) as batch:
        lengths = self.lengths[batch]
        z = torch.tensor(0, dtype=torch.long)
        y = torch.zeros(self.args.batch_size, 1)
        input_batch = input_seq[batch, :]

        for t in pyro.markov(range(0, self.lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                z_current = self.state_emitter(input_batch[:, t, :], z).argmax(dim=1)  # expected shape [batch_size X 1]
                z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(probs_lat)[..., z_current, :]))  # expected shape [batch_size X 1]
    ...
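The post doesn't show state_emitter itself. For reference, here is a minimal sketch of what such a module might look like, assuming it returns one row of unnormalized scores per sequence. The class name StateEmitter, the embedding of the previous state, and all layer sizes are hypothetical. The sketch only handles a single initial state or one state per sample, not an enumerated z (which carries extra enumeration dimensions). The thing to notice is the output shape: argmax(dim=1) over a [batch_size, num_states] tensor returns shape [batch_size], not the [batch_size X 1] the comment expects.

```python
import torch
import torch.nn as nn


class StateEmitter(nn.Module):
    """Hypothetical emitter: scores every candidate current state from the
    previous state and the current observation."""

    def __init__(self, obs_dim: int, num_states: int, hidden_dim: int = 16):
        super().__init__()
        self.state_embedding = nn.Embedding(num_states, hidden_dim)
        self.net = nn.Sequential(
            nn.Linear(obs_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_states),
        )

    def forward(self, obs: torch.Tensor, prev_state: torch.Tensor) -> torch.Tensor:
        # obs: [batch_size, obs_dim]. prev_state: integer tensor holding either a
        # single state (e.g. the initial tensor(0)) or one state per sample.
        prev_state = prev_state.reshape(-1).expand(obs.shape[0]).contiguous()
        h = self.state_embedding(prev_state)          # [batch_size, hidden_dim]
        return self.net(torch.cat([obs, h], dim=-1))  # [batch_size, num_states]


emitter = StateEmitter(obs_dim=5, num_states=3)
obs = torch.randn(8, 5)                  # stand-in for input_batch[:, t, :]
scores = emitter(obs, torch.tensor(0))
print(scores.shape)                      # torch.Size([8, 3])
print(scores.argmax(dim=1).shape)        # torch.Size([8]), not [8, 1]
```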

The problem is that I don't want z to be batched twice, but I do want to use z_current, which is per-sample. I need z to have a single state for each sample in the batch, i.e. z.shape == torch.Size([8, 1]); instead I get z.shape == torch.Size([8, 8]). How would you suggest changing the architecture, or which Pyro handler would you use, to disable this double-batching effect?
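The [8, 8] can be reproduced with plain PyTorch shape arithmetic, under two assumptions: probs_lat is a [num_states, num_states] transition matrix (in which case Vindex(probs_lat)[..., z_current, :] reduces to probs_lat[z_current]), and state_emitter returns [batch_size, num_states] scores. Random tensors stand in for both. torch.broadcast_shapes stands in for the expansion pyro.plate performs; for these shapes it arrives at the same [8, 8].

```python
import torch
from torch.distributions import Categorical

batch_size, num_states = 8, 3
# Hypothetical transition matrix standing in for probs_lat:
# row i holds p(z_t | z_{t-1} = i), so every row sums to 1.
probs_lat = torch.softmax(torch.randn(num_states, num_states), dim=-1)
# Hypothetical stand-in for the state_emitter output.
scores = torch.randn(batch_size, num_states)

# argmax(dim=1) drops dim 1, so each sequence's state index ends up in dim -1.
z_current = scores.argmax(dim=1)                                      # [8]
shape_without_keepdim = Categorical(probs_lat[z_current]).batch_shape  # [8]

# A plate at dim=-2 with 8 subsampled sequences wants batch shape [8, 1].
# [8] is right-aligned against it (it sits in dim -1), so the two broadcast to [8, 8].
plate_shape = (batch_size, 1)
broadcast_shape = torch.broadcast_shapes(shape_without_keepdim, plate_shape)  # [8, 8]

# keepdim=True keeps the index in dim -2, which lines up with the plate.
z_current = scores.argmax(dim=1, keepdim=True)                        # [8, 1]
shape_with_keepdim = Categorical(probs_lat[z_current]).batch_shape     # [8, 1]

print(shape_without_keepdim, broadcast_shape, shape_with_keepdim)
```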

BTW, I would really appreciate some input on my other post.

Hi, this might be caused by setting dim=-2 in your plate call, meaning that the plate dimension does not coincide with the batch shape of z. Is 8 your subsample_size? What happens when you set dim=-1?