Why does a Cholesky-based reparameterization in a multidimensional IRT model converge for small N but fail for large N in NumPyro?

I am fitting a 5-dimensional multidimensional item response theory (MIRT) model in NumPyro.
At the person level, I model the correlation among the 5 latent abilities with an LKJCholesky prior on the Cholesky factor of the correlation matrix.

For the person abilities theta, I tried two parameterizations that are mathematically equivalent:

  • Reparameterized: sample independent standard normals and map them through the Cholesky factor
    (theta_raw ~ Normal(0, 1) and theta = (psi_mat_L @ theta_raw).T).
  • Non-reparameterized: sample theta directly from a multivariate normal
    (MultivariateNormal(0, scale_tril=psi_mat_L).expand([N])).

The item parameters are simulated and then treated as known. Responses follow a 2PL-like logit with Bernoulli observations.
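Written out, the model is roughly as follows. This assumes a slope-intercept 2PL with known discriminations a_j and intercepts b_j for J items (the number of items and the exact 2PL form are not stated above, and the gist may parameterize it differently). The LKJ concentration η is left unspecified:

$$
\begin{aligned}
L &\sim \operatorname{LKJCholesky}(5, \eta), \qquad \Omega = L L^\top, \\
\text{reparameterized:}\quad z_i &\sim \mathcal{N}_5(0, I_5), \qquad \theta_i = L z_i, \\
\text{direct:}\quad \theta_i &\sim \mathcal{N}_5(0, L L^\top), \\
\operatorname{logit} \Pr(y_{ij} = 1 \mid \theta_i) &= a_j^\top \theta_i + b_j, \qquad i = 1, \dots, N,\ j = 1, \dots, J.
\end{aligned}
$$

Because \(\operatorname{Cov}(L z_i) = L I_5 L^\top = \Omega\), both versions imply the same distribution for each \(\theta_i\). Only the coordinates that NUTS moves in differ.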

A minimal example is in this gist:

https://gist.github.com/qipengchen/2b9f4c25ab17424f570f273a5c346cfb

Results

Keeping the model and priors fixed, I change only:

  • the sample size N (number of persons), and
  • whether I use the reparameterized or non-reparameterized version for theta.

From az.rhat(idata).max() and trace plots, I observe:

  • N = 100 (small sample)
    • Reparameterized: converges well (max r_hat ≈ 1).
    • Non-reparameterized: also converges well.
  • N = 10,000 (large sample)
    • Reparameterized: shows serious convergence issues
      (max r_hat clearly > 1, trace plots indicate poor mixing).
    • Non-reparameterized: still converges, with r_hat close to 1.

So, somewhat surprisingly to me, the “more standard-looking” Cholesky-based (non-centered) reparameterization works fine for small N but mixes poorly as N grows. Meanwhile, the direct multivariate normal version, which uses the same Cholesky factor as scale_tril, keeps converging.

Questions

  1. From a theoretical or numerical point of view, why might the Cholesky-based reparameterization become harder for NUTS to sample from as the number of persons N increases?
  2. In this setting, is it reasonable to simply use the non-reparameterized version (MultivariateNormal(...).expand([N])) as a practical solution, or is that likely to hide deeper issues?
  3. Are there other, more suitable ways to handle this kind of model in NumPyro when N is large?

Any insight, intuition, or pointers to references would be very much appreciated. Thank you for your time!