import numpyro

# `model` is the NumPyro model function defined elsewhere.
with numpyro.handlers.seed(rng_seed=0):
    trace = numpyro.handlers.trace(model).get_trace()
print(numpyro.util.format_shapes(trace))
Using numpyro.util.format_shapes, I can check the shape of each variable in the model. But how can I check the shapes when enumeration is used?