Categorical prior - plate statement

I have a 3D array x of shape (d1, d2, d3) and would like to draw a categorical sample along the third dimension, producing a 2D array y of shape (d1, d2). This is easy to write in NumPyro and works outside a model:

y = numpyro.sample("y", rng_key=random.PRNGKey(0), fn=dist.Categorical(x)) # x = (d1,d2,d3), y = (d1,d2)

However, when I use the same statement inside a model, it raises the following error:

AssertionError: Missing plate statement for batch dimensions at site y

I also tried the same statement inside a plate context, using different values of dim for the plate, but it gives the following error:

ValueError: Incompatible shapes for broadcasting: ((1, d1), (d1, d2))

Any help is appreciated!

If the first two dimensions are batch dimensions (i.e. each element of the y array is sampled independently given x), you may need to enclose the site in one numpyro.plate context for each batch dimension:

...
with numpyro.plate("d1", d1, dim=-2), numpyro.plate("d2", d2, dim=-1):
    ...
    y = numpyro.sample("y", dist.Categorical(x))
...

See the Pyro tensor shapes tutorial for more background.

Thanks @eb8680_2 for the response. The dimensions d1 and d2 are actually not batch dimensions.

Let me share a simplified version of my model to make the problem clearer:

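# Imports assumed by this snippet.
import numpy as np
import numpyro
import numpyro.distributions as dist

def BayesianModel(x):
    d1, d2, d3 = x.shape

    # Hyperpriors sampled once, outside the plate.
    w_star = numpyro.sample("w_star", dist.Dirichlet(0.001 * np.ones(d2)))
    gamma_star = numpyro.sample("gamma_star", dist.Gamma(0.01, 0.01))

    # Categorical(x) has batch shape (d1, d2) and is sampled outside any plate.
    a_w = numpyro.sample("aws", dist.Categorical(x))

    with numpyro.plate("d1", d1, dim=-1):
        w = numpyro.sample("w", dist.Dirichlet(gamma_star * w_star))
        numpyro.sample("a_w", dist.Multinomial(probs=w), obs=a_w)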

I got the following error:

ValueError: Incompatible shapes for broadcasting: ((1, d2), (d2, d1))

I also tried putting a_w inside another plate context, as follows:

with numpyro.plate("ds", d1, dim=-2):
    a_w = numpyro.sample('aws', dist.Categorical(x))

with numpyro.plate("d1", d1, dim=-1):
...

But I got the following error (apparently one cannot have two plates on the same variable):

KeyError: 'd1_2'

Do you have any idea how it can be solved? Thanks in advance!

Sorry, I’m not sure what the problem is. Your code appears to have several errors or inconsistencies (what is aw_blf, and what is its shape? why are you attempting to observe the discrete value a_w with a Dirichlet distribution?), and you don’t seem to have followed my suggestion above to use two plates around a_w. Can you provide a runnable script that reproduces the errors you’re seeing?

@eb8680_2 Sorry for the errors; they crept in while I was simplifying the model. I have now edited the model in the message above.

I also implemented your suggestion, but got the following error:

ValueError: Incompatible shapes for broadcasting: ((d1, d2), (d3, 1))

Thanks for taking the time @eb8680_2!