Hi,

Thank you again for this wonderful platform.

Recently, while sampling a numpyro model with NUTS, I noticed that even when the sampler takes a fixed number of integration steps per iteration (say 2^5), the s/it metric can vary considerably. For example, the following printout was observed at different stages of sampling a single model:

```
warmup: 3%|▎ | 12/400 [09:41<11:49:48, 109.76s/it, 31 steps of size 9.75e-06. acc. prob=0.13]
warmup: 3%|▎ | 13/400 [14:57<18:31:29, 172.32s/it, 31 steps of size 4.97e-05. acc. prob=0.20]
warmup: 4%|▎ | 14/400 [20:12<23:05:25, 215.35s/it, 31 steps of size 2.53e-04. acc. prob=0.26]
```

versus

```
warmup: 19%|█▉ | 76/400 [1:49:26<27:09:10, 301.70s/it, 31 steps of size 1.05e-01. acc. prob=0.47]
warmup: 19%|█▉ | 77/400 [1:54:55<27:47:56, 309.83s/it, 31 steps of size 9.54e-02. acc. prob=0.47]
warmup: 20%|█▉ | 78/400 [2:01:04<29:18:20, 327.64s/it, 31 steps of size 2.09e-01. acc. prob=0.48]
...
sample: 27%|██▋ | 109/400 [4:17:03<32:48:59, 405.98s/it, 31 steps of size 5.48e-02. acc. prob=0.49]
```

In case it is helpful, I am using `max_tree_depth=5` and `target_accept_prob=0.5`.
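For context, my sampler setup looks roughly like this (`my_model` is a placeholder for my actual model, and the warmup/sample counts are illustrative):

```python
import jax.random as random
from numpyro.infer import MCMC, NUTS

# my_model stands in for the actual numpyro model definition
kernel = NUTS(my_model, max_tree_depth=5, target_accept_prob=0.5)
mcmc = MCMC(kernel, num_warmup=200, num_samples=200)
mcmc.run(random.PRNGKey(0))
```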

Naively, I was expecting the runtime of a NUTS chain to be approximately

`(N_warmup + N_sample) * (k * t_forwardpass) * 2^(max_tree_depth) / target_accept_prob`

where `t_forwardpass` is the time it takes to run the model forward once, and `k` is the extra factor of time needed for gradient backpropagation. Is this intuition correct? It does not seem to explain why the model takes very different amounts of time to execute a single iteration even when the number of integration steps is fixed.
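To make my expectation concrete, here is the back-of-envelope calculation I had in mind, with purely hypothetical numbers for `t_forward`, `k`, and the iteration counts:

```python
# Naive runtime estimate for a NUTS run; all numbers are hypothetical placeholders.
n_warmup, n_sample = 200, 200
t_forward = 0.5           # seconds for one forward model evaluation (made up)
k = 3.0                   # assumed extra factor for the gradient evaluation (made up)
max_tree_depth = 5
target_accept_prob = 0.5

t_leapfrog = k * t_forward            # cost of one leapfrog step
steps_per_iter = 2 ** max_tree_depth  # 32 leapfrog steps at full tree depth
est_runtime = (n_warmup + n_sample) * t_leapfrog * steps_per_iter / target_accept_prob
print(f"{est_runtime:.0f} s")  # prints: 38400 s
```

Since `steps_per_iter` is fixed at 32 here, this estimate predicts a roughly constant per-iteration cost, which is not what I observe.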

Thank you!

Alan