I am training an HMM-like model in which the transition logits are computed from the previous state and a learned embedding for each sequence, and are then passed to pyro.distributions.Categorical:
z[t] = pyro.sample(f"z_{t}", dist.Categorical(logits=logits))
During sampling, the shape of logits is (64, 1, 5) and during enumeration the shape of logits is (5, 1, 64, 1, 5).
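To make the two shapes concrete, here is a small sketch using plain torch.distributions.Categorical rather than Pyro's wrapper, so it does not go through the branch shown in the traceback below. It only shows how each logits shape is split into batch shape and categories. The random logits stand in for the model's real ones, and placing the enumerated value in the leftmost dim is an assumption made for illustration, not necessarily the dim Pyro picks.

```python
import torch
from torch.distributions import Categorical

num_states = 5
sample_logits = torch.randn(64, 1, num_states)      # shape seen during sampling
enum_logits = torch.randn(5, 1, 64, 1, num_states)  # shape seen during enumeration

# The last dim indexes the states; everything to its left is batch shape.
print(Categorical(logits=sample_logits).batch_shape)  # torch.Size([64, 1])
print(Categorical(logits=enum_logits).batch_shape)    # torch.Size([5, 1, 64, 1])

# Hypothetical enumerated value: all 5 states laid out along a new leftmost dim.
value = torch.arange(num_states).reshape(num_states, 1, 1, 1, 1)
# The generic gather-based log_prob broadcasts value against the batch shape.
print(Categorical(logits=enum_logits).log_prob(value).shape)  # torch.Size([5, 5, 1, 64, 1])
```

The leading 5 in the enumeration-time logits is therefore part of the batch shape, not the category dim.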
After one batch I get the following error, raised from the source code of Pyro's Categorical distribution. Sampling and enumeration both work; it fails only when the log prob is computed:
File /opt/anaconda3/envs/cellfate/lib/python3.12/site-packages/pyro/distributions/torch.py:141, in Categorical.log_prob(self, value)
137 logits = logits.reshape(
138 (1,) * (1 + value.dim() - logits.dim()) + logits.shape
139 )
140 if not torch._C._get_tracing_state():
--> 141 assert logits.size(-1 - value.dim()) == 1
142 return logits.transpose(-1 - value.dim(), -1).squeeze(-1)
143 return super().log_prob(value)
AssertionError:
I'm not sure what could be causing this, because I'm not certain what exactly this assertion is checking. I'm hoping the Pyro devs can shed some light on what it is trying to do.
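Reading the lines in the traceback, the branch appears to be a shortcut for scoring an enumerated value: instead of gathering one entry per element (line 143 falls back to the parent class otherwise), it moves the category dim of logits into the dim where value lays out its candidate states. The sketch below is a hypothetical, standalone re-creation of just those lines in plain PyTorch. shortcut_log_prob is an illustrative name, not a Pyro function, and it assumes value is torch.arange(K) reshaped to (K, 1, ..., 1).

```python
import torch


def shortcut_log_prob(logits: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-creation of the reshape/transpose shortcut in the traceback.

    Assumes `value` is torch.arange(K) reshaped to (K, 1, ..., 1), i.e. an
    enumerated support whose K candidate states sit in the leftmost dim of `value`.
    """
    # Normalize, as torch's Categorical does with the logits it stores.
    logits = logits - logits.logsumexp(-1, keepdim=True)
    if logits.dim() <= value.dim():
        # Left-pad logits with singleton dims (traceback lines 137-139).
        logits = logits.reshape((1,) * (1 + value.dim() - logits.dim()) + logits.shape)
    # The dim of logits that lines up with value's leftmost (enumeration) dim.
    enum_dim = -1 - value.dim()
    # It must be a singleton so the K category scores can be moved into it (line 141).
    assert logits.size(enum_dim) == 1
    return logits.transpose(enum_dim, -1).squeeze(-1)


# Sampling-time logits, states enumerated in batch dim -3: the check passes.
ok = shortcut_log_prob(torch.randn(64, 1, 5), torch.arange(5).reshape(5, 1, 1))
print(ok.shape)  # torch.Size([5, 64, 1])

# Enumeration-time logits, states placed in batch dim -4, where the logits
# already have size 5: the same assertion fails.
try:
    shortcut_log_prob(torch.randn(5, 1, 64, 1, 5), torch.arange(5).reshape(5, 1, 1, 1))
except AssertionError:
    print("AssertionError: logits already vary along the enumeration dim")
```

Under this reading, the assertion checks that the logits passed for z_t have size 1 in the batch dim that enumeration assigned to z_t, i.e. that they do not already vary along that dim.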