Hi all,
I apologize in advance; this is probably a trivial issue but I currently cannot see the solution.
While setting up a model and performing MCMC/NUTS inference on it, I am encountering a “UserWarning: Missing a plate statement for batch dimension”. I had this before and was able to resolve it, but in this case I just don’t see what is wrong / how to solve it.
Imports
import numpyro
import numpyro.distributions as dist
from numpyro import sample, plate
from numpyro.infer import Predictive, NUTS, MCMC
from jax import random
import numpy as np
from icecream import ic
def model(loc=None, y=None):
    if loc is None:
        loc = sample("loc", dist.Normal(loc=0.0, scale=2.0))
    if y is None:
        nobs = 1
    else:
        nobs = len(y)
    with plate("N", nobs):
        lat = sample("lat", dist.Normal(loc=np.ones((nobs,)) * loc, scale=1.0))  # unobserved
        label_dist = dist.Bernoulli(logits=lat)
        ic(label_dist.batch_shape)  # just for debugging
        ic(label_dist.event_shape)  # just for debugging
        label = sample("label", label_dist, obs=y)
        ic(label.shape)  # just for debugging
Forward-simulating the model works fine:
data = Predictive(lambda **kwargs: model(loc=-1), num_samples=1000)(random.PRNGKey(1))
The trace shapes look exactly as I would expect:
with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(model).get_trace(y=np.ones((15,)))
print(numpyro.util.format_shapes(trace))
yields
Trace Shapes:
 Param Sites:
Sample Sites:
     loc dist    |
        value    |
      N plate 15 |
     lat dist 15 |
        value 15 |
   label dist 15 |
        value 15 |
But when I run MCMC, I get “UserWarning: Missing a plate statement for batch dimension -2 at site ‘label’.”
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_chains=4, num_warmup=500)
mcmc.run(random.PRNGKey(0), y=data["label"])
For what it’s worth, inference seems to have worked reasonably well:
mcmc.get_samples()['loc'].mean() # yields -0.86 vs. true value -1.0.
(I am actually interested in a more complex model, and I am worried that things going wrong with the plate dimensions might mess up inference in that more complex model. Hence my question even though inference seems to work fine here.)
During the MCMC run, the ic calls (it’s basically just a slightly nicer print(x)) yield the following:
ic| label_dist.batch_shape: (1000,)
ic| label_dist.event_shape: ()
ic| label.shape: (1000, 1)
I thought that the issue might be somehow related to the last dimension of the label variable not being empty (why?), so I tried adding a .reshape(-1) to the label sample statement. That changes the shape to (1000,), but the missing plate warning is still produced, even though the label variable doesn’t even have two dimensions in that case.
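As a sanity check on these shapes, here is a minimal NumPy-only sketch; it does not involve NumPyro. It assumes that the observed value is broadcast against the distribution's batch shape in the usual NumPy way, which I have not verified inside NumPyro. The array obs is a hypothetical stand-in with the same shape as the label.shape reported above, not the real data. np.broadcast_shapes needs NumPy 1.20 or newer.

```python
import numpy as np

# Shapes reported by ic during the MCMC run.
batch_shape = (1000,)        # label_dist.batch_shape
obs = np.ones((1000, 1))     # hypothetical stand-in with the reported label.shape

# Assumption: the observed value is broadcast against the batch shape
# following standard NumPy broadcasting rules.
broadcast_shape = np.broadcast_shapes(batch_shape, obs.shape)

# Dropping the trailing singleton axis of the observed array itself
# (rather than reshaping the value returned by sample).
obs_flat = obs.squeeze(-1)
flat_shape = np.broadcast_shapes(batch_shape, obs_flat.shape)

print(broadcast_shape, flat_shape)
```

Under that assumption, the first shape has an extra leading dimension of size 1000, which would sit at position -2, the dimension named in the warning. The second shape is one-dimensional and lines up with the plate. If this is what is going on, the reshape may need to happen on the array passed as y (for example y=data["label"].squeeze(-1)) rather than on the value returned by sample, though I have not confirmed that this silences the warning.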
Any ideas or suggestions? Many thanks in advance for any help!