Problem with enumeration for batch data in deep Markov models

@eb8680_2 here is a runnable example of the code; I tried to keep it as close to the original problem as possible.
I am working with batch_dim = -1, and the problem seems to come from the masking, specifically the .unsqueeze(-1).
Do you think I should drop the .unsqueeze(-1)? And how should I move to a model with batch_dim = -2? (I put a rough sketch of what I imagine the -2 variant looks like after the full example below.)

The line in question is:

with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length).unsqueeze(-1))
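
To make the shapes concrete, this is what that mask evaluates to on its own (plain torch, with the same batch_size and seq_length as the full example below):

import torch

batch_size, seq_length = 30, 1000
t = 5
mask = t < torch.ones(batch_size) * seq_length   # bool tensor, all True while t < seq_length
print(mask.shape)                # torch.Size([30])
print(mask.unsqueeze(-1).shape)  # torch.Size([30, 1])

The full runnable example:
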
from unittest import TestCase

import torch
from torch import nn

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex

pyro.enable_validation()
pyro.set_rng_seed(0)
num_seq = 100
seq_length = 1000


class TestEmitter(nn.Module):
    def __init__(self, w_dim=1, z_dim=1):
        super().__init__()
        self.lin_w_to_hidden = nn.Linear(w_dim, z_dim)
        # normalize over the output dimension so each row is a probability vector
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, w, z):
        # w is a 1-d batch of scalars; z is currently unused
        w = w.unsqueeze(-1)
        w_hidden = self.lin_w_to_hidden(w)
        out_sm = self.softmax(w_hidden)  # shape (batch, z_dim)
        return out_sm


@config_enumerate
def model(inp):
    num_states = 6
    batch_size = 30
    x = inp['x']
    y = inp['y']

    state_emitter = TestEmitter(w_dim=1, z_dim=6)
    obs_emitter = TestEmitter(w_dim=1, z_dim=25)
    pyro.module("state_emitter", state_emitter)
    pyro.module("obs_emitter", obs_emitter)

    with poutine.mask(mask=True):
        probs_lat = pyro.sample("probs_lat",
                                dist.Dirichlet(0.5 * torch.eye(num_states) + 0.5 / (num_states - 1)).to_event(1))

    # initial values for the z / y_hat recurrence
    z = torch.zeros(1)
    y_hat = torch.zeros(1)
    with pyro.plate("sequence", size=num_seq, subsample_size=batch_size) as batch:
        for t in pyro.markov(range(0, seq_length - 1)):
            # with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length)):
            with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length).unsqueeze(-1)):
                px = state_emitter(x[t, batch].float(), z)
                z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(probs_lat)[..., px.argmax(dim=1), :]))
                py = obs_emitter(z.float(), y_hat)
                y_hat = pyro.sample(f"y_{t}", dist.Categorical(py), obs=y[t, batch])


def generate_data():
    x = torch.randint(0, 30, size=(seq_length, num_seq))
    y = torch.randint(0, 25, size=(seq_length, num_seq))
    seq = dict(x=x, y=y)
    return seq


def print_shapes(model, guide, seq, first_available_dim=-3):
    guide_trace = poutine.trace(guide).get_trace(seq)
    model_trace = poutine.trace(
        poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(seq)
    print(model_trace.format_shapes())


class TestPyroARDBN(TestCase):

    def test_architecture(self):
        seq = generate_data()
        guide = AutoDelta(poutine.block(model, expose=["probs_lat"]))
        pyro.clear_param_store()
        elbo = TraceEnum_ELBO(max_plate_nesting=3)
        print_shapes(model, guide, seq, first_available_dim=-3)
        elbo.loss(model, guide, seq)
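
To clarify the second question: below is roughly what I imagined the batch_dim=-2 variant of the plate/mask would look like. The dim=-2 argument and the (batch_size, 1) mask shape are just my guesses, not something I have working.

# sketch only, not verified: sequence plate pinned to dim=-2
with pyro.plate("sequence", size=num_seq, subsample_size=batch_size, dim=-2) as batch:
    for t in pyro.markov(range(0, seq_length - 1)):
        # is a (batch_size, 1) mask still the right shape here, or should it stay (batch_size,)?
        with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length).unsqueeze(-1)):
            ...

If that is the right direction, I assume first_available_dim / max_plate_nesting would also need to change, but I am not sure what the right values are.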