Unexpected "Missing Plate" warning

Hi all,

I apologize in advance; this is probably a trivial issue but I currently cannot see the solution.

While setting up a model and performing MCMC/NUTS inference on it, I am encountering a “UserWarning: Missing a plate statement for batch dimension”. I had this before and was able to resolve it, but in this case I just don’t see what is wrong / how to solve it.

Imports
import numpyro
import numpyro.distributions as dist
from numpyro import sample, plate
from numpyro.infer import Predictive, NUTS, MCMC
from jax import random
import numpy as np
from icecream import ic

Model
def model(loc=None, y=None):
    if loc is None:
        loc = sample("loc", dist.Normal(loc=0.0, scale=2.0))

    if y is None:
        nobs = 1
    else:
        nobs = len(y)

    with plate("N", nobs):
        lat = sample("lat", dist.Normal(loc=np.ones((nobs,)) * loc, scale=1.0))  # unobserved
        label_dist = dist.Bernoulli(logits=lat)
        ic(label_dist.batch_shape)  # just for debugging
        ic(label_dist.event_shape)  # just for debugging
        label = sample("label", label_dist, obs=y)
        ic(label.shape)  # just for debugging

Forward-simulating the model works fine:

data = Predictive(lambda **kwargs: model(loc=-1), num_samples=1000)(random.PRNGKey(1))

The trace shapes look exactly as I would expect:

with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(model).get_trace(y=np.ones((15,)))
print(numpyro.util.format_shapes(trace))

yields

Trace Shapes:     
 Param Sites:     
Sample Sites:     
     loc dist    |
        value    |
      N plate 15 |
     lat dist 15 |
        value 15 |
   label dist 15 |
        value 15 |

But when I run MCMC I get "UserWarning: Missing a plate statement for batch dimension -2 at site ‘label’":

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_chains=4, num_warmup=500)
mcmc.run(random.PRNGKey(0), y=data["label"])

For what it’s worth, inference seems to have worked reasonably well:

mcmc.get_samples()['loc'].mean()  # yields -0.86 vs. true value -1.0.

(I am actually interested in a more complex model, and I am worried that things going wrong with the plate dimensions might mess up inference in that more complex model. Hence my question even though inference seems to work fine here.)

During the MCMC run, the ic calls (it’s basically just a slightly nicer print(x)) yield the following:

ic| label_dist.batch_shape: (1000,)
ic| label_dist.event_shape: ()
ic| label.shape: (1000, 1)

I thought that the issue might be somehow related to the last dimension of the label variable not being empty (why?), so I tried adding a .reshape(-1) to the label sample statement. That changes the shape to (1000,) but the missing plate warning is still produced – even though the label variable doesn’t even have two dimensions in that case.

Any ideas or suggestions? Many thanks in advance for any help!

I think you should look at the shape of data["label"].
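
To make that hint concrete, here is a minimal NumPy-only sketch of what seems to be going on. It assumes data["label"] has shape (1000, 1), as the ic output for label.shape suggests (1000 predictive draws of a plate of size 1). The names y, dist_batch_shape and y_flat are stand-ins for illustration, not part of the original script.

```python
import numpy as np

# Stand-in for data["label"]; the shape (1000, 1) is an assumption taken from
# the ic output above (num_samples=1000 draws, plate size nobs=1).
y = np.ones((1000, 1))

# len() only looks at the leading axis, so the plate "N" gets size 1000.
nobs = len(y)
dist_batch_shape = (nobs,)

# The observed value still carries the trailing axis, so broadcasting it
# against the distribution's batch shape yields an extra batch dimension
# that no plate accounts for.
joint_shape = np.broadcast_shapes(dist_batch_shape, y.shape)
print(nobs, joint_shape)

# Flattening the data before it is passed to the model removes that axis.
y_flat = y.reshape(-1)
flat_joint_shape = np.broadcast_shapes((len(y_flat),), y_flat.shape)
print(y_flat.shape, flat_joint_shape)
```

If that assumption holds, passing the flattened array, for example mcmc.run(random.PRNGKey(0), y=data["label"].reshape(-1)), should make the observation shape match the plate. Reshaping the label inside the model does not help, because by then the observed value has already been broadcast against the distribution.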

Ahh yes that was it, thank you so much!! (I really appreciate your taking the time to respond to beginner questions like this - thank you!)