After enabling 64-bit precision for JAX arrays, I have a complex hierarchical multinomial logit model working properly with NumPyro’s NUTS sampler.
Now I have some questions about tuning the NUTS sampler and my model:
- Signals for tuning max_tree_depth.
I know that increasing the depth can lead to better sampling performance, while decreasing it usually improves computational efficiency. When I run the adaptive warmup with the default max_tree_depth (10), the number of steps settles at 1,023 (that is, 2^10 - 1, the cap for depth 10) after the first few iterations.
Can I treat this as a signal that I should increase max_tree_depth to 11 or so? Similarly, should I decrease max_tree_depth to 9 if I observe fewer than 511 (2^9 - 1) steps? Are there any rules of thumb for tuning it?
- GPU memory cannot accommodate dense_mass for a complex model.
While browsing past posts in the forum, I found that enabling dense_mass is often mentioned together with tuning max_tree_depth in NUTS settings. In my case, I tried this, but the model is so complex that even a GPU with 48GB of memory cannot hold a dense mass matrix.
Are there other general suggestions for tuning the NUTS sampler when dense_mass is not applicable?
- Efficiency / Speed Concerns
I know that the NUTS sampler adapts the number of steps and the step size, which in turn determines the running time. Before enabling 64-bit precision, my model got stuck in the very early iterations, and in that state warmup and sampling appeared very fast, about 50-70 it/s. After enabling the higher precision, sampling became very slow: more than 6 s/it when the number of steps is 1,023. I ran a warmup of 7,500 iterations in 6.5 hours and the model still had not converged. I'm currently trying a longer chain, but it will also take much longer to fit.
I would like to know whether there are general suggestions or practical tricks for fitting a model efficiently. I'm open to, e.g., customized initial values, parallel runs, other internal/external samplers, SVI, multi-stage inference, etc. But I need the chains to converge, since I care about posterior details for decision making.
- Improving efficiency of the model code
I also want to improve efficiency by optimizing my model code. I have already done some work on this, e.g., writing vectorized computations, jitting some functions, and using JAX control flow instead of Python's.
I wonder whether there is a tool in NumPyro that I can use to find which part of my code takes the most time during inference, so that I can refine it in a targeted manner. (Other general suggestions on improving the model code are also welcome.)
Sorry for asking so many questions at once. Thank you in advance for any feedback!