Hi,

While trying to solve a problem regarding enumeration in my neural DBN model, I ran into a different kind of problem with the batching mechanism: it only appears when `subsample_size > 1`.

Specifically, I am iterating over a time-series input and have a module that predicts my current state from two inputs: the last state (the enumerated variable `z`) and an input observation that I process in batch mode (`input_batch[:, t, :]`):

```
def model():
    ...
    with pyro.plate("sequence_list", size=self.num_seqs, subsample_size=self.batch_size, dim=-2) as batch:
        lengths = self.lengths[batch]
        z = torch.tensor(0, dtype=torch.long)
        y = torch.zeros(self.args.batch_size, 1)
        input_batch = input_seq[batch, :]
        for t in pyro.markov(range(0, self.lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                z_current = self.state_emitter(input_batch[:, t, :], z).argmax(dim=1)  # expected shape [batch_size X 1]
                z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(probs_lat)[..., z_current, :]))  # expected shape [batch_size X 1]
                ...
```
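For reference, with a single index tensor the `Vindex` expression above reduces to plain advanced indexing, so the resulting batch dimension sits at `dim=-1`. A quick shape check with hypothetical sizes (3 latent states, batch of 8):

```python
import torch

# Hypothetical sizes: K = 3 latent states, batch of 8 samples
probs_lat = torch.rand(3, 3)                  # transition probabilities
z_current = torch.zeros(8, dtype=torch.long)  # one predicted state per sample

# With one index tensor, Vindex(probs_lat)[..., z_current, :] behaves
# like plain advanced indexing: one row of probs_lat per sample.
rows = probs_lat[z_current]
print(rows.shape)  # torch.Size([8, 3])
```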

The problem is that I do not need to double-batch over `z`, but I do want to use `z_current`, which is per-sample. I need `z` to hold a single state for each sample in the batch, i.e. `z.shape = torch.Size([8, 1])`; however, I get `z.shape = torch.Size([8, 8])`. How would you suggest changing the architecture, or which Pyro handler would you use, to disable this double-batching effect?
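My current guess (please correct me if I am wrong) is that this is a broadcasting mismatch: the plate sits at `dim=-2`, but `z_current` has shape `[8]`, so the `Categorical` batch dimension lands at `dim=-1` and broadcasts against the plate dimension. A shape-only sketch with hypothetical sizes:

```python
import torch

plate_shape = (8, 1)  # pyro.plate(..., subsample_size=8, dim=-2)

# z_current of shape [8] puts the Categorical batch dim at -1,
# so it broadcasts against the plate dim instead of aligning with it:
batch_misaligned = (8,)
print(torch.broadcast_shapes(plate_shape, batch_misaligned))  # torch.Size([8, 8])

# z_current.unsqueeze(-1) (shape [8, 1]) aligns the batch dim with the plate:
batch_aligned = (8, 1)
print(torch.broadcast_shapes(plate_shape, batch_aligned))  # torch.Size([8, 1])
```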

BTW, I would really appreciate some input on my other post.