Error while packing tensors, HMM with enumeration

def hmm_model(K, N, observations, concentration, mu_loc, mu_scale, sigma_shape, sigma_scale):
X = {}
Y = {}
theta = pyro.sample("theta", dist.Dirichlet(torch.ones((K, K)) * concentration / K))
mus = pyro.sample(
    "mus", dist.Normal(torch.ones(K) * mu_loc, torch.ones(K) * mu_scale)
)
sigmas = pyro.sample(
    "sigmas", dist.Normal(torch.ones(K) * sigma_shape, torch.ones(K) * sigma_scale)
)
X[0] = pyro.sample(
    "X[0]", dist.Categorical(tensor([1.0] + [0.0] * (K - 1))), obs=tensor(0)
)
Y[0] = pyro.sample(
    "Y[0]", dist.Normal(mus[X[0]], sigmas[X[0]]), obs=observations[0]
)
for i in pyro.markov(range(1, N)):
    X[i] = pyro.sample(
        "X[" + str(i) + "]",
        dist.Categorical(theta[X[i - 1]]),
        infer={"enumerate": "parallel"},
    )
    Y[i] = pyro.sample(
        "Y[" + str(i) + "]",
        dist.Normal(mus[X[i]], sigmas[X[i]]),
        obs=observations[i],
    )

Hi, I am trying to run a simple HMM model in Pyro with NUTS and enumeration for discrete variables. However, I am running into an issue with the shape of theta (transition matrix). Any ideas what the error means?

Also is there any way I can sample the discrete latent states (X[i])?