Problem with enumeration for batch data in Deep markov models

Another update:

  • I removed the obs_plate (since my output space is single-dimensional) so that the rightmost dimension is reserved for the batch
  • used Vindex for sampling z

The code:

def model(self, sequences, include_prior=True):
        """Deep Markov model over a minibatch of sequences.

        At each time step t:
          1. emit a discrete state ``emit_x_t`` from the state emitter,
          2. transition the latent ``z_t`` via the Dirichlet transition matrix,
          3. observe ``y_t`` through the autoregressive emitter.

        Args:
            sequences: training sequences (observed outputs are read from
                ``output_seq``; see NOTE below).
            include_prior: mask for the prior term — when False the
                ``probs_lat`` prior is masked out of the ELBO.
        """
        ...
        pyro.module("state_emitter", self.state_emitter)
        pyro.module("ar_emitter", self.ar_emitter)

        with poutine.mask(mask=include_prior):
            # Transition matrix over hidden states [num_states x num_states].
            # .to_event(1) folds the row dimension into the event so that
            # probs_lat is a single sample site with event_shape (6, 6).
            probs_lat = pyro.sample("probs_lat", dist.Dirichlet(
                0.5 * torch.eye(self.num_states) + 0.5 / (self.num_states - 1)).to_event(1))
        with pyro.plate("sequence_list", size=self.num_seqs, subsample_size=self.batch_size) as batch:
            lengths = self.lengths[batch]
            z = 0
            # NOTE(review): self.args.batch_size here vs self.batch_size in the
            # plate above — confirm the two always agree, otherwise the mask
            # broadcast will silently misalign.
            y = torch.zeros(self.args.batch_size, 1)
            # NOTE(review): input_seq is not defined in this excerpt —
            # presumably derived from `sequences`; verify against the caller.
            input_batch = input_seq[batch, :]
            # NOTE(review): 'max_lenght' looks like a typo for 'max_length' —
            # keep it in sync with wherever the attribute is defined.
            for t in pyro.markov(range(0, self.max_lenght if self.args.jit else self.lengths.max())):
                # Mask out padded time steps beyond each sequence's length.
                with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                    px = self.state_emitter(input_batch[:, t, :], z)  # px.shape = [batch_size x num_states]
                    emitted_x = pyro.sample(f"emit_x_{t}", dist.Categorical(px))  # emitted_x.shape = [batch_size x 1]

                    # FIX: the shape note below was bare code ("z.shape = ...")
                    # after the call — a SyntaxError; it is now a comment.
                    z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(probs_lat)[..., emitted_x, :]))  # z.shape = [batch_size x 1]

                    py = self.ar_emitter(y, z)  # FIX: comment said px; py.shape = [batch_size x num_emission]
                    y = pyro.sample(f"y_{t}", dist.Categorical(py), obs=output_seq[batch, t])

Both runs — the guide and the enumeration check — seem fine.
The trace shapes info is

...
      Sample Sites:                       
     probs_lat dist                  | 6 6
              value                  | 6 6
 sequence_list dist                  |    
              value             30   |    
      emit_x_0 dist             30   |    
              value       6  1   1   |    
           z_0 dist       6  1  30   |    
              value     6 1  1   1   |    
           y_0 dist       6  1  30   |    
              value         30   1   |    
      emit_x_1 dist       6  1  30   |    
              value   6 1 1  1   1   |    
           z_1 dist   6 1 1  1  30   |    
              value 6 1 1 1  1   1   |    
           y_1 dist   6 1 1  1  30   |    
              value         30   1   |  
...

However, during svi.step I receive the following error:

ValueError: at site "emit_x_0", invalid log_prob shape
  Expected [30], actual [30, 30]
  Try one of the following fixes:
  - enclose the batched tensor in a with plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions
# the svi step
        self.elbo = Elbo(max_plate_nesting=2)
        optim = Adam({'lr': self.args.learning_rate})
        svi = SVI(self.model, self.guide, optim, self.elbo)

        # We'll train on small minibatches.
        self.logger.info('Step\tLoss')
        for step in range(self.args.num_steps):
            loss = svi.step(self.sequences)

I don’t really understand what’s wrong, given that the enumeration dimensions make sense.
I can also add the code for ar_emitter and state_emitter; however, I doubt that the problem is there.
Is there something wrong with the plates or to_event usage?