I am fitting a 5-dimensional multidimensional item response theory (MIRT) model in NumPyro.
At the person level, I model the correlation among the 5 latent abilities with an LKJCholesky prior on the Cholesky factor of the correlation matrix.
For the person abilities theta, I tried two parameterizations that are mathematically equivalent:
- Reparameterized: sample independent standard normals and map them through the Cholesky factor (`theta_raw ~ Normal(0, 1)` and `theta = (psi_mat_L @ theta_raw).T`).
- Non-reparameterized: sample `theta` directly from a multivariate normal (`MultivariateNormal(0, scale_tril=psi_mat_L).expand([N])`).
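Both variants target the same distribution for `theta`: if each column of `theta_raw` is standard normal, then each row of `(L @ theta_raw).T` has covariance `L @ L.T`, which is exactly the covariance of `MultivariateNormal(0, scale_tril=L)`. The plain NumPy check below illustrates this with a made-up 5x5 correlation matrix standing in for `psi_mat_L`; it does not use NumPyro and is only a sanity check of the equivalence, not part of the model.

```python
import numpy as np

rng = np.random.default_rng(0)
D, N = 5, 200_000

# Hypothetical correlation matrix R (unit diagonal) and its Cholesky factor,
# used here only as a stand-in for psi_mat_L.
A = rng.normal(size=(D, D))
S = A @ A.T + D * np.eye(D)
s = np.sqrt(np.diag(S))
R = S / np.outer(s, s)
psi_mat_L = np.linalg.cholesky(R)

# Reparameterized: theta_raw ~ N(0, 1) with shape (D, N), then theta = (L @ theta_raw).T
theta_raw = rng.standard_normal((D, N))
theta_reparam = (psi_mat_L @ theta_raw).T  # shape (N, D)

# Non-reparameterized: draw each row of theta directly from MVN(0, L L^T)
theta_direct = rng.multivariate_normal(np.zeros(D), psi_mat_L @ psi_mat_L.T, size=N)

# Empirical covariances of the two sets of draws; both should be close to R
cov_reparam = np.cov(theta_reparam, rowvar=False)
cov_direct = np.cov(theta_direct, rowvar=False)
```

Since the two constructions agree in distribution, any difference in sampling behaviour has to come from how the posterior geometry looks to NUTS in each set of unconstrained coordinates.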
The item parameters are simulated and then treated as known. Responses follow a 2PL-like logit with Bernoulli observations.
A minimal example is in this gist:
https://gist.github.com/qipengchen/2b9f4c25ab17424f570f273a5c346cfb
Results
Using the same model and priors, I change only:
- the sample size `N` (number of persons), and
- whether I use the reparameterized or non-reparameterized version of `theta`.
From `az.rhat(idata).max()` and trace plots, I observe:
- N = 100 (small sample)
  - Reparameterized: converges well (max `r_hat` ≈ 1).
  - Non-reparameterized: also converges well.
- N = 10,000 (large sample)
  - Reparameterized: shows serious convergence issues (max `r_hat` clearly > 1; trace plots indicate poor mixing).
  - Non-reparameterized: still converges, with `r_hat` close to 1.
So, somewhat surprisingly to me, the “more standard-looking” Cholesky-based reparameterization works fine for small N but becomes harder to fit as N grows, while the direct multivariate normal version remains usable.
Questions
- From a theoretical or numerical point of view, why might the Cholesky-based reparameterization become harder for NUTS to sample from as the number of persons N increases?
- In this setting, is it reasonable to simply use the non-reparameterized version (`MultivariateNormal(...).expand([N])`) as a practical solution, or is that likely to hide deeper issues?
- Are there other, more suitable ways to handle this kind of model in NumPyro when N is large?
Any insight, intuition, or pointers to references would be very much appreciated. Thank you for your time!