I apologize in advance; this is probably a trivial issue but I currently cannot see the solution.
While setting up a model and performing MCMC/NUTS inference on it, I am encountering a “UserWarning: Missing a plate statement for batch dimension”. I had this before and was able to resolve it, but in this case I just don’t see what is wrong / how to solve it.
import numpyro
import numpyro.distributions as dist
from numpyro import sample, plate
from numpyro.infer import Predictive, NUTS, MCMC
from jax import random
import numpy as np
from icecream import ic

def model(loc=None, y=None):
    if loc is None:
        loc = sample("loc", dist.Normal(loc=0.0, scale=2.0))
    if y is None:
        nobs = 1
    else:
        nobs = len(y)
    with plate("N", nobs):
        lat = sample("lat", dist.Normal(loc=np.ones((nobs,)) * loc, scale=1.0))  # unobserved
        label_dist = dist.Bernoulli(logits=lat)
        ic(label_dist.batch_shape)  # just for debugging
        ic(label_dist.event_shape)  # just for debugging
        label = sample("label", label_dist, obs=y)
        ic(label.shape)  # just for debugging
Forward-simulating the model works fine:
data = Predictive(lambda **kwargs: model(loc=-1), num_samples=1000)(random.PRNGKey(1))
The trace shapes look exactly as I would expect:
with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(model).get_trace(y=np.ones((15,)))
print(numpyro.util.format_shapes(trace))

Trace Shapes:
 Param Sites:
Sample Sites:
     loc dist    |
        value    |
      N plate 15 |
     lat dist 15 |
        value 15 |
   label dist 15 |
        value 15 |
But when I run MCMC, I get "UserWarning: Missing a plate statement for batch dimension -2 at site 'label'."

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_chains=4, num_warmup=500)
mcmc.run(random.PRNGKey(0), y=data["label"])
For what it’s worth, inference seems to have worked reasonably well:
mcmc.get_samples()['loc'].mean() # yields -0.86 vs. true value -1.0.
(I am actually interested in a more complex model, and I am worried that things going wrong with the plate dimensions might mess up inference in that more complex model. Hence my question even though inference seems to work fine here.)
During the MCMC run, the ic calls (it's basically just a slightly nicer print(x)) yield the following:

ic| label_dist.batch_shape: (1000,)
ic| label_dist.event_shape: ()
ic| label.shape: (1000, 1)
I thought that the issue might be somehow related to the last dimension of the label variable not being empty (why?), so I tried adding a .reshape(-1) to the label sample statement. That changes the shape to (1000,), but the missing plate warning is still produced, even though the label variable doesn't even have two dimensions in that case.
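
For reference, the shapes reported by the ic calls can be examined in isolation with plain NumPy. This is only a sketch of one possible reading, not a confirmed diagnosis: it assumes that the observations passed as y have shape (1000, 1) (as label.shape suggests) and that the log-probability at the label site broadcasts the distribution's batch shape against the shape of y. The names y, plate_shape, log_prob_shape, and y_flat are illustrative and do not appear in the model above.

```python
import numpy as np

# Stand-in for data["label"]; the shape (1000, 1) is an assumption taken
# from the ic output label.shape == (1000, 1).
y = np.zeros((1000, 1))

# len() only counts the leading axis, so the plate would get size 1000.
nobs = len(y)
plate_shape = (nobs,)  # batch shape of label_dist inside plate("N", nobs)

# Broadcasting a (1000,) batch shape against (1000, 1) observations
# produces a two-dimensional result, i.e. an extra axis at dimension -2.
log_prob_shape = np.broadcast_shapes(plate_shape, y.shape)
print(log_prob_shape)

# Flattening the observations themselves (rather than the value returned
# by the sample statement) keeps everything one-dimensional.
y_flat = y.reshape(-1)
flat_shape = np.broadcast_shapes((len(y_flat),), y_flat.shape)
print(flat_shape)
```

If this reading applies, the difference from my .reshape(-1) attempt would be that the reshape here is applied to the observations before they enter the model, e.g. y=data["label"].reshape(-1) in the call to mcmc.run, rather than to the output of the sample statement.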
Any ideas or suggestions? Many thanks in advance for any help!