I am quite confused by the way batch and event dimensions are treated when computing the log density (log_prob()). As far as I understand their definitions, event dimensions should be summed over in log_prob(), while batch dimensions should be kept in its output:
But if I do the same inside a model definition, I do not observe the behaviour above. Instead, I observe this:
What confuses me is that the log_pdf of the model always seems to behave as if its dimensions were event dimensions, i.e. they are summed over. What's going on?
More specifically:
- Why is the log_pdf of model_a above not batched over its independent dimensions?
- Is there a way to examine the shapes of sites inside a model? In Pyro there is the trace.format_shapes() command, described in the tutorial "Tensor shapes in Pyro" (Pyro Tutorials 1.8.1 documentation). Is there an equivalent in numpyro?
Thank you for any help!