How/when does providing initial inverse_mass_matrix speed up NUTS warmup?

I’m doing inference on a 6-variable ODE system with 6 free parameters. I noticed that some chains finish their 1000 warmup iterations (and 1000 samples, with 0 divergences) within ~1 hour. A few chains take much longer (5-7 hours) because they keep hitting max_tree_depth=8 (2^8 = 256 leapfrog steps per iteration) with small step sizes of ~1e-3, given my target_accept=0.8. A single evaluation of my model takes ~0.02 s and its gradient ~0.04 s, so each iteration is rather expensive, and I’d like to improve convergence speed.

I suspect this has something to do with the random initial parameters chosen by init_to_sample (my prior predictive checks look fine) – perhaps the curvature of the loss landscape is extremely steep (or flat?) in some dimensions at the init points of the chains with problematic/slow warmup. I can try to verify this by computing the exact Hessian of the loss function, evaluated at each chain’s initial parameters, with JAX autodiff.
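Concretely, something along these lines (a minimal sketch – the quadratic loss below is just a placeholder for my actual ODE loss):

```python
import jax
import jax.numpy as jnp

# Placeholder loss (negative log posterior); the real one wraps the ODE model.
def loss(params):
    return 0.5 * jnp.sum(jnp.arange(1.0, 7.0) * params**2)

init_params = jnp.ones(6)                    # one chain's init point
hess = jax.hessian(loss)(init_params)        # exact Hessian via autodiff
eigvals = jnp.linalg.eigvalsh(hess)          # ascending eigenvalues (symmetric)
print("eigenvalues:", eigvals)
print("condition number:", eigvals[-1] / eigvals[0])
```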

Would it help speed up step-size and mass-matrix adaptation during warmup if I further inverted that hessian(loss)(init_params) and fed it to NUTS as inverse_mass_matrix? Rather than starting from an assumed identity mass matrix and slowly learning its structure (I’d use dense_mass=True), could NUTS take larger/smarter steps from that initial inverse_mass_matrix guess? (I would still set adapt_mass_matrix=True and adapt_step_size=True.)
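In code, what I have in mind is roughly this (a sketch with a toy model in place of my ODE model; I’m assuming a recent NumPyro where NUTS accepts an inverse_mass_matrix argument – please correct me if the usage is off):

```python
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Toy 6-parameter model standing in for the real ODE model.
def model():
    numpyro.sample("theta", dist.Normal(jnp.zeros(6), 1.0))

# Stand-in for jnp.linalg.inv(hessian(loss)(init_params)).
inv_hess = jnp.eye(6)

kernel = NUTS(
    model,
    dense_mass=True,               # adapt a dense mass matrix during warmup
    inverse_mass_matrix=inv_hess,  # initial guess instead of the identity
    adapt_mass_matrix=True,
    adapt_step_size=True,
    target_accept_prob=0.8,
)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))
```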

i think it’s unlikely that hessians will help you because what you really want is a mass matrix that encodes information about the global distribution and not some local hessian which may be arbitrarily different (or even ill defined).

if however you’re running the same/similar problems again and again then it may be beneficial to warm start with a mass matrix obtained from a previous run.

how do i access the mass matrix from another chain that completed? and would that be the final adapted mass matrix from the last warmup iteration?

how do i access the mass matrix from another chain that completed?

dunno off the top of my head but you should be able to find other threads in the forum on the topic or otherwise just look at the code

and would that be the final adapted mass matrix from the last warmup iteration?

yes it’s frozen at that point.
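For future readers: from a quick look at the NumPyro source, the frozen adaptation state appears to live on mcmc.last_state, so warm-starting a later run could look roughly like this (continuing from the snippet above; field names follow HMCState/HMCAdaptState and may differ across versions):

```python
# Reuse the warmup-adapted state from a finished (single-chain) run.
adapt_state = mcmc.last_state.adapt_state
inv_mm = adapt_state.inverse_mass_matrix   # frozen at the end of warmup
step_size = adapt_state.step_size

kernel2 = NUTS(model, dense_mass=True,
               inverse_mass_matrix=inv_mm, step_size=step_size)
mcmc2 = MCMC(kernel2, num_warmup=200, num_samples=1000)  # shorter warmup
mcmc2.run(random.PRNGKey(1))
```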

Thanks @martinjankowiak! Do you have any other suggestions for comparing the chains that finish warmup in ~10 min and successfully converge to the mock-truth parameter values vs. the chains that took ~5 hours and got stuck in some other local minimum?

I’ve made trace plots of the logL, the parameter values, norm(grad(loss)), and the individual components of grad(loss) vs. warmup/sample iteration number. Nothing currently stands out as to why some chains get stuck.
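For reference, I’m pulling the per-iteration diagnostics via extra_fields (a sketch; the field names come from NumPyro’s HMCState, and I believe nested fields like adapt_state.step_size are supported):

```python
# Collect per-iteration NUTS internals alongside the samples.
mcmc.run(random.PRNGKey(0),
         extra_fields=("potential_energy", "num_steps",
                       "accept_prob", "adapt_state.step_size"))
extras = mcmc.get_extra_fields(group_by_chain=True)

# num_steps pinned near 2**max_tree_depth together with a tiny step size
# flags the crawling chains.
print(extras["num_steps"].max(axis=1))
print(extras["adapt_state.step_size"][:, -1])
```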

Maybe the posterior truly is multimodal? But the logL of the local minimum that ~2/10 chains get stuck in is -15, whereas the logL at the truth (where most chains converge) is -5, so clearly higher. (This is with uniform priors, so logL equals the log posterior up to a constant.)

hmm it’s a hard problem in general… did you try different init strategies as demo’d e.g. here?

Yes @martinjankowiak, I tried both init_to_sample and init_to_value.

It turns out there really were two widely separated modes, but I can justify setting a tighter prior to exclude the lower-likelihood one, so I think I’ll be fine.

However, I noticed that when I do use init_to_value and provide jnp.linalg.inv(hessian(logL)(init_params)) as inverse_mass_matrix to NUTS, it finishes the 1000 warmup iterations taking only ~1 step per iteration, with a ridiculously small step size like 1e-64 or 1e-300. So the trace of the warmup parameter values is basically a flat line (the parameters don’t move). Any idea why this might be?

warmup: 100%|██████████| 1000/1000 [03:40<00:00, 4.54it/s, 1 steps of size 2.23e-308. acc. prob=0.00]

The initial Hessian’s condition number is ~350 and its jnp.linalg.eigvalsh output is [-0.02943668 0.02865003 0.07081831 0.21354916 1.86937356 10.64486206] – so not crazy?

i don’t know what’s going on in your particular problem, but if your initial mass matrix thinks a certain direction has very high curvature, it’ll only take small steps in that direction. that means it won’t explore that direction much, so it’ll have a hard time figuring out that the direction may not actually be that highly curved from a global perspective. as a result you’ll have a hard time recovering from a poorly chosen initial mass matrix, and sampling will crawl to a halt.
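also note that hmc’s kinetic energy assumes a positive definite mass matrix, and the eigvalsh output you posted has a negative eigenvalue, i.e. your hessian is indefinite at that init point, so its inverse isn’t a valid inverse mass matrix to begin with (and you’d want the hessian of the negative log posterior rather than of logL anyway). if you really want a hessian-based guess you could project it onto something positive definite first, e.g. something like this sketch (the eigenvalue floor eps is an arbitrary choice):

```python
import jax.numpy as jnp

# Project an indefinite Hessian onto the positive-definite cone, then invert.
def pd_inverse(hess, eps=1e-2):
    eigvals, eigvecs = jnp.linalg.eigh(hess)
    eigvals = jnp.maximum(jnp.abs(eigvals), eps)  # flip/floor bad directions
    return (eigvecs / eigvals) @ eigvecs.T        # V diag(1/lam) V^T

# use pd_inverse(hessian(loss)(init_params)) in place of jnp.linalg.inv(...)
```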