AssertionError in Pyro Categorical distribution

I am training an HMM-like model in which the transition logits are computed from the previous state and a learned embedding for each sequence, then passed to pyro.distributions.Categorical:

```python
z[t] = pyro.sample(f"z_{t}", dist.Categorical(logits=logits))
```

During sampling, the shape of logits is (64, 1, 5) and during enumeration the shape of logits is (5, 1, 64, 1, 5).
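The extra leading dimensions during enumeration come from broadcasting. Under parallel enumeration, Pyro replaces the previous state with a tensor holding every state, arranged along an enumeration dimension to the left of the plates. A minimal sketch, assuming a hypothetical transition table `transition` and per-sequence term `seq_term` (neither is taken from the original model), and assuming the previous state is enumerated at dim -4 purely to reproduce the reported shape:

```python
import torch

num_states, num_seqs = 5, 64

# Hypothetical parameters on the logit scale (not from the original model).
transition = torch.randn(num_states, num_states)  # (prev_state, next_state)
seq_term = torch.randn(num_seqs, 1, num_states)   # one row per sequence

# Sampling: z_prev holds one concrete state per sequence.
z_prev = torch.randint(num_states, (num_seqs, 1))  # (64, 1)
logits_sampled = transition[z_prev] + seq_term     # (64, 1, 5)

# Parallel enumeration: z_prev is replaced by all states, laid out along an
# enumeration dim left of the plates (dim -4 here is an illustrative choice).
z_prev_enum = torch.arange(num_states).reshape(num_states, 1, 1, 1)
logits_enum = transition[z_prev_enum] + seq_term   # broadcasts to (5, 1, 64, 1, 5)

print(logits_sampled.shape)  # torch.Size([64, 1, 5])
print(logits_enum.shape)     # torch.Size([5, 1, 64, 1, 5])
```

Because every term broadcasts against the trailing dimensions, the category axis stays at -1 no matter what Pyro prepends, which is the layout Categorical expects.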

After one batch I get the following error (sampling and enumeration work, but computing the log prob fails). It is raised inside Categorical.log_prob in Pyro's source:

```
File /opt/anaconda3/envs/cellfate/lib/python3.12/site-packages/pyro/distributions/torch.py:141, in Categorical.log_prob(self, value)
    137         logits = logits.reshape(
    138             (1,) * (1 + value.dim() - logits.dim()) + logits.shape
    139         )
    140     if not torch._C._get_tracing_state():
--> 141         assert logits.size(-1 - value.dim()) == 1
    142     return logits.transpose(-1 - value.dim(), -1).squeeze(-1)
    143 return super().log_prob(value)

AssertionError:
```
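Reading the lines in the traceback, this branch appears to be a fast path for when `value` is the unexpanded enumerated support: a tensor like `arange(K)` reshaped to `(K, 1, ..., 1)`, whose `K` sits in this site's enumeration dimension at dim `-value.dim()`. Rather than gathering, it left-pads `logits` and swaps the category axis into the batch position aligned with that enumeration dimension (`-1 - value.dim()`, one further left because of the event axis). The assertion checks that this position is a singleton in `logits`; if it already holds a non-1 size, for example another enumerated variable's dimension, the swap would be invalid. The sketch below replays only that shape logic in plain Python. The function names are hypothetical, and the shapes and `value_dim` values are illustrative assumptions:

```python
def _padded_logits_shape(logits_shape, value_dim):
    # Left-pad with singleton dims so logits has at least value_dim + 1 dims,
    # mirroring the reshape shown in the traceback.
    shape = tuple(logits_shape)
    if len(shape) <= value_dim:
        shape = (1,) * (1 + value_dim - len(shape)) + shape
    return shape


def enum_dim_is_free(logits_shape, value_dim):
    """True if the logits position aligned with the value's enumeration dim
    has size 1, i.e. the condition the assertion checks."""
    shape = _padded_logits_shape(logits_shape, value_dim)
    return shape[-1 - value_dim] == 1


def fast_path_log_prob_shape(logits_shape, value_dim):
    """Shape returned by the fast path: transpose the category axis into the
    enumeration slot, then squeeze the now-singleton last axis."""
    shape = _padded_logits_shape(logits_shape, value_dim)
    pos = -1 - value_dim
    assert shape[pos] == 1, f"size {shape[pos]} at dim {pos}, expected 1"
    swapped = list(shape)
    swapped[pos], swapped[-1] = swapped[-1], swapped[pos]
    return tuple(swapped[:-1])


# Sampling-mode logits; this site enumerated at dim -3 (value shape (5, 1, 1)).
print(fast_path_log_prob_shape((64, 1, 5), 3))        # (5, 64, 1)
# Enumeration-mode logits whose extra 5 sits at tensor dim -5.
print(fast_path_log_prob_shape((5, 1, 64, 1, 5), 3))  # (5, 5, 64, 1)
# If this site's own enumeration dim were -4, it would collide with that 5.
print(enum_dim_is_free((5, 1, 64, 1, 5), 4))          # False
```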

What could cause this? I'm not sure exactly what this assertion checks, so I'm hoping the Pyro devs can explain what it is trying to enforce.

Solved: the issue was caused by calling moveaxis on logits, which seems to break the log prob computation.
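The post does not show the exact moveaxis call, so the following is only one plausible way it can go wrong. It assumes a hypothetical intermediate tensor `scores` with the category axis first, moved to the end with a positive axis index. Positive indices count from the left, so once enumeration prepends dimensions, axis 0 is no longer the category axis; a negative index counts from the right and is unaffected:

```python
import torch

num_states, num_seqs = 5, 64

# Hypothetical intermediate with the category axis FIRST:
# shape (num_states, num_seqs, 1) when nothing is enumerated.
scores = torch.randn(num_states, num_seqs, 1)

# Under enumeration, extra dims appear on the LEFT, e.g. an enumerated
# previous state of size 5 (layout chosen for illustration).
scores_enum = torch.randn(num_states, 1, num_states, num_seqs, 1)

# Fragile: axis 0 is the category axis only when nothing is prepended,
# so here it moves the enumeration dim to the end instead.
bad = scores_enum.moveaxis(0, -1)
# Robust: address the category axis from the right.
good = scores_enum.moveaxis(-3, -1)

print(scores.moveaxis(0, -1).shape)  # torch.Size([64, 1, 5])
print(bad.shape)                     # torch.Size([1, 5, 64, 1, 5])
print(good.shape)                    # torch.Size([5, 1, 64, 1, 5])
```

In the `bad` case the enumerated previous state ends up in the category slot and the real categories land at dim -4, a batch position Pyro may allocate to another enumeration dimension. A mix-up like this could trip the assertion or silently produce wrong log probs; keeping the category axis last, or addressing it only with negative indices, avoids it.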