NUTS: estimate runtime using forward-pass time and sec/it metric

Hi,

Thank you again for this wonderful platform.

Recently, when sampling a NumPyro model with NUTS, I noticed that even when the model samples with a fixed number of steps (say 2^5), the sec/it metric can be very different. For example, the following output was observed at different stages of sampling a single model:

warmup:   3%|▎         | 12/400 [09:41<11:49:48, 109.76s/it, 31 steps of size 9.75e-06. acc. prob=0.13]
warmup:   3%|▎         | 13/400 [14:57<18:31:29, 172.32s/it, 31 steps of size 4.97e-05. acc. prob=0.20]
warmup:   4%|▎         | 14/400 [20:12<23:05:25, 215.35s/it, 31 steps of size 2.53e-04. acc. prob=0.26]

versus

warmup:  19%|█▉        | 76/400 [1:49:26<27:09:10, 301.70s/it, 31 steps of size 1.05e-01. acc. prob=0.47]
warmup:  19%|█▉        | 77/400 [1:54:55<27:47:56, 309.83s/it, 31 steps of size 9.54e-02. acc. prob=0.47]
warmup:  20%|█▉        | 78/400 [2:01:04<29:18:20, 327.64s/it, 31 steps of size 2.09e-01. acc. prob=0.48]
...
sample:  27%|██▋       | 109/400 [4:17:03<32:48:59, 405.98s/it, 31 steps of size 5.48e-02. acc. prob=0.49]

In case it might be helpful, I am using max_tree_depth=5 and target_accept_prob=0.5.
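
For concreteness, here is a minimal sketch of that sampler configuration; the toy model, the data, and the iteration counts are placeholders rather than my actual setup:

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Toy model standing in for my actual model.
def toy_model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

y = jnp.ones(100)  # placeholder data

kernel = NUTS(toy_model, max_tree_depth=5, target_accept_prob=0.5)
mcmc = MCMC(kernel, num_warmup=200, num_samples=200)  # counts are placeholders
mcmc.run(random.PRNGKey(0), y)
```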

Naively, I was expecting the runtime of a NUTS chain to be approximately

(N_warmup + N_sample) * (k * t_forwardpass) * 2^(max_tree_depth) / target_accept_prob

where t_forwardpass is the time it takes to run the model forward once, and k is the extra factor of time needed for gradient backpropagation. Is this intuition correct? It does not seem able to explain why the model takes very different amounts of time to execute a single iteration even when the number of integration steps is fixed.
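
To make the back-of-envelope explicit, here is a small sketch that just encodes the estimate above; it is not NumPyro's actual cost model, and all of the numbers plugged in at the end are hypothetical:

```python
def estimated_runtime_s(n_warmup, n_sample, t_forwardpass, k,
                        max_tree_depth, target_accept_prob):
    # A NUTS iteration that saturates the tree does 2**max_tree_depth - 1 leapfrog
    # steps, each needing a gradient of the potential (~ k * t_forwardpass).
    steps_per_iter = 2 ** max_tree_depth - 1
    return (n_warmup + n_sample) * k * t_forwardpass * steps_per_iter / target_accept_prob

# Hypothetical numbers: 400 total iterations, a 3 s forward pass, gradient ~2.5x slower.
print(estimated_runtime_s(200, 200, 3.0, 2.5, 5, 0.5) / 3600, "hours")
```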

Thank you!
Alan

For NUTS, target_accept_prob does not affect the computational time. I guess this is related to the speed of low-level computation, e.g. the time to compute log(1) vs. log(2) can be different.
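
One way to check this for a specific model is to time a single gradient of the potential (i.e. roughly one leapfrog step) at different points in parameter space. A sketch, again with a toy model standing in for the real one and the perturbation scales chosen arbitrarily:

```python
import time

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model

def toy_model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

y = jnp.ones(1000)
model_info = initialize_model(jax.random.PRNGKey(0), toy_model, model_args=(y,))
grad_fn = jax.jit(jax.grad(model_info.potential_fn))

z0 = model_info.param_info.z  # unconstrained initial parameters
for scale in (1.0, 10.0, 100.0):
    z = {k: v * scale for k, v in z0.items()}  # mimic states visited later in warmup
    jax.block_until_ready(grad_fn(z))          # compile / warm up before timing
    t0 = time.perf_counter()
    for _ in range(100):
        out = grad_fn(z)
    jax.block_until_ready(out)
    print(f"scale={scale}: {(time.perf_counter() - t0) / 100:.6f} s per gradient")
```

If the per-gradient time is roughly constant across points, the variation in sec/it would have to come from somewhere other than the cost of the individual leapfrog steps.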