I have a 3D array x of shape (d1, d2, d3) and would like to draw a categorical sample along the third dimension, producing a 2D array y of shape (d1, d2). With NumPyro this is a one-liner, and it seems to work outside of a model:
y = numpyro.sample("y", dist.Categorical(probs=x), rng_key=random.PRNGKey(0))  # x.shape == (d1, d2, d3), y.shape == (d1, d2)
However, when I use the same call inside a model, it raises the following error:
AssertionError: Missing plate statement for batch dimensions at site y
I also tried wrapping the call in a plate statement (with various values of the plate's dim argument), but that gives the following error:
ValueError: Incompatible shapes for broadcasting: ((1, d1), (d1, d2))
Any help is appreciated!