Trailing vs leading cluster dims

I’m building a NumPyro model that batches linear algebra over clusters with hierarchical priors. I solve T×C OLS problems (T days × C clusters), each over N data points, using a single set of batched matrix operations (no Python loop or vmap).
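For concreteness, the “single set of matrix operations” is a batched normal-equations solve, roughly like the sketch below. All shapes and names here are illustrative, not my real model; I’ve written it with numpy, but `jax.numpy` exposes the same API.

```python
import numpy as np  # jax.numpy (jnp) has the same API for all calls below

# Hypothetical sizes: C clusters, T days, N points, K regressors.
C, T, N, K = 3, 4, 200, 2
rng = np.random.default_rng(0)
X = rng.normal(size=(C, T, N, K))                  # one (N, K) design per (cluster, day)
beta_true = rng.normal(size=(C, T, K, 1))
y = X @ beta_true + 0.05 * rng.normal(size=(C, T, N, 1))

# A single batched normal-equations solve for all C*T problems at once;
# np.linalg.solve broadcasts over the leading (C, T) batch dims.
Xt = np.swapaxes(X, -1, -2)                        # (C, T, K, N)
beta_hat = np.linalg.solve(Xt @ X, Xt @ y)         # (C, T, K, 1)
```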

Because I use `plate` for the hierarchical priors, the sampled per-cluster parameters (and the downstream broadcasts) naturally led me to keep the cluster dimension as a trailing axis; e.g. my data is shaped (T, N, C).

When I asked ChatGPT, it claimed that JAX can optimize the linear-algebra portion much better when the batch dims are leading, and suggested I refactor the tensor layout so that the batch dimensions T and C lead, e.g. (C, T, N). Its recommended pattern is:

  • use plate only for sampling per-cluster random variables
  • broadcast / reshape those parameters to match the compute layout
  • index / slice the data to match, and run all the linear algebra with leading batch dims
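As I understand the second step, a `plate` over clusters gives me parameters of shape (C,) (or with extra singleton dims), which I’d then reshape so the cluster axis lines up with the leading batch dim of the compute layout. A minimal sketch of what I think is meant (numpy shown; broadcasting works identically in `jax.numpy`, and `slope` is a made-up stand-in for a plate-sampled parameter):

```python
import numpy as np  # jnp broadcasting behaves the same way

C, T, N = 3, 4, 5
slope = np.arange(C, dtype=float)       # per-cluster parameter from plate, shape (C,)
data = np.ones((C, T, N))               # data already in the (C, T, N) compute layout

# Reshape (C,) -> (C, 1, 1) so the cluster axis broadcasts as the leading batch dim:
pred = slope[:, None, None] * data      # -> (C, T, N)
```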

Questions:

  1. Is it true that leading batch dimensions (e.g. (C,T,N)) are materially more performant than trailing cluster dimensions (e.g. (T,N,C)) for batched linear algebra under JAX/XLA? Or is ChatGPT hallucinating?
  2. If yes, is it OK, given my existing code, to just slap a transpose ahead of the linear-algebra portion? Or is it important to refactor so the tensors are natively produced in (C,T,N)? The output of the linear-algebra portion feeds the log likelihood, so the transpose would happen on every model evaluation.
  3. Are there any pitfalls, or cases where keeping the cluster dim trailing is preferable?
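For question 2, by “slap a transpose” I mean literally just this at the top of the linear-algebra section (numpy shown; `jnp.moveaxis` has the same signature — though I don’t know whether XLA would fuse it away or materialize a copy each eval, which is part of what I’m asking):

```python
import numpy as np

T, N, C = 5, 100, 3
data_tnc = np.zeros((T, N, C))            # my current trailing-cluster layout
data_ctn = np.moveaxis(data_tnc, -1, 0)   # (C, T, N): clusters leading
# In numpy this is a zero-copy view; under jit, XLA decides the actual layout.
```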

Thanks!