Slow runs for larger iterations

Hello numpyro devs and users,

I’ve noticed that my numpyro run (using NUTS) slows down considerably when the number of iterations is increased. The same code run for 30 iterations and for 2500 iterations takes ~13 s/it and ~144 s/it respectively during the warmup stage (with max_tree_depth=5):

warmup:  30%|█▌  | 9/30 [01:59<04:29, 12.85s/it, 31 steps 
                         of size 7.07e-03. acc. prob=0.67]
warmup:   1%|▏   | 33/2500 [1:03:02<98:43:46, 144.07s/it, 31 steps 
                            of size 1.40e-02. acc. prob=0.76]

I would naively assume that the time taken per iteration would be the same in both cases. Is this behaviour expected? Are there ways to correct it so that the larger run achieves the same per-iteration speed?


Please keep in mind that NUTS is an adaptive algorithm: the number of steps taken per iteration changes from iteration to iteration. As the algorithm explores different parts of the posterior, more or fewer gradient steps may be needed, so execution time will be variable. As such, this is expected, and I don’t see any reason to think it reflects a software issue. To see convincing evidence that this is e.g. a software issue, you would need to run very, very long chains (where the iteration-to-iteration variation washes out as you explore large parts of the space over many iterations) and show that e.g. a chain with 200k iterations takes substantially more than twice the time of a chain with 100k iterations.

Thanks @martinjankowiak. I should’ve made it clearer: from my observations, the number of steps taken per warmup iteration is 31 (as I’ve set max_tree_depth to 5) in both cases. This is also visible in the progress-bar snippets above.

What I noticed is that, for multiple runs of 30 iterations, the entire run completed in about 5 to 6 minutes. In every case, the number of steps taken was 31 (the first few iterations take a small number of steps, but from about iteration 6 onward the step count is consistently 31). Since this also happens to be the maximum allowed by the imposed tree depth (31 = 2^5 - 1), shouldn’t the time taken for 31 steps per iteration be constant across runs? In the 2500-iteration case the number of steps per iteration is also 31, but as the snippet in my question shows, after 33 iterations the larger run had already taken more than an hour, while the smaller run is expected to complete in about 5 minutes (and it does).

I’m not claiming that this is a software issue; I’m merely trying to understand what the reason could be and how I can mitigate it. Does it also depend on where in parameter space the particular chain is located? Are more computations happening under the hood for each step, which could possibly slow it down?

Could you try using the same num_samples, or disabling adaptation (either step size or mass matrix)? Sometimes calculations with small numbers are faster than with large numbers (or vice versa), but that shouldn’t cause such a huge difference. Maybe your model has a lot of parameters, and the cost of allocating space for more samples is expensive?

Thanks @fehiepsi. I’ll try that out.

To answer your other question, the model has 8 parameters. I’ve checked the memory utilization during both the runs and they seem to take less than 10% of the available memory.

I guess the size of the space allocated for many more samples does not matter for such a slow model. Could this be a random-seed issue? Are you using an ODE in your model, which would involve a while_loop whose speed cannot be estimated properly? Another possibility is that some operator in your model uses a while_loop and hits a tricky region to evaluate or take gradients in (e.g. gammaln). If you let us know why your model is slow, maybe we can provide better hints.

There is no while_loop in the model. There is a fori_loop (imported below as foril).
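(One reason the while_loop/fori_loop distinction matters: a fori_loop with static, concrete bounds is lowered to scan and therefore supports the reverse-mode gradients NUTS needs, whereas lax.while_loop is not reverse-mode differentiable. A minimal sketch with a made-up loop body:)

```python
import jax
from jax.lax import fori_loop as foril

def poly(x):
    # static bounds -> fori_loop lowers to scan, so reverse-mode grad works;
    # the same body inside lax.while_loop would not be reverse-differentiable
    return foril(0, 5, lambda i, acc: acc + x ** (i + 1), 0.0)

# d/dx (x + x^2 + x^3 + x^4 + x^5) at x = 1 is 1 + 2 + 3 + 4 + 5 = 15
g = jax.grad(poly)(1.0)
print(g)
```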
Strangely, the slowest part of the model seems to be this function:

NAX = jnp.newaxis
def get_eig_corr(clp, z1):
    return (clp.conj() * ((z1 * clp[:, NAX, :]).sum(axis=0))).sum(axis=0)
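As an aside, the same double contraction can be written as a single `jnp.einsum` call, which sometimes fuses better under jit. The shapes below (`clp` of shape `(m, n)`, `z1` of shape `(m, m, n)`) are inferred from the broadcasting in `get_eig_corr` and are an assumption:

```python
import numpy as np
import jax.numpy as jnp

NAX = jnp.newaxis

def get_eig_corr(clp, z1):
    return (clp.conj() * ((z1 * clp[:, NAX, :]).sum(axis=0))).sum(axis=0)

def get_eig_corr_einsum(clp, z1):
    # the same contraction: sum_{i,j} conj(clp[j,k]) * z1[i,j,k] * clp[i,k]
    return jnp.einsum('jk,ijk,ik->k', clp.conj(), z1, clp)

# check equivalence on random data with the assumed shapes
m, n = 4, 3
rng = np.random.default_rng(0)
clp = jnp.asarray(rng.normal(size=(m, n)) + 1j * rng.normal(size=(m, n)))
z1 = jnp.asarray(rng.normal(size=(m, m, n)))
print(np.allclose(get_eig_corr(clp, z1), get_eig_corr_einsum(clp, z1), atol=1e-5))
```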

get_eig_corr is called from within a fori_loop in the model:

def model():
    c_arr = numpyro.sample('c_arr', dist.Uniform(cmin, cmax))
    pred_acoeffs = jnp.zeros(num_j * nmults)
    c_params = c_arr * true_params

    def scale_bkm(mult_idx, bkm_full):
        bkm_full = jidx_update(bkm_full,
                               jidx[mult_idx, :, :, :],  # index by mult_idx, not an undefined i
                               -1.0 * bkm_full[mult_idx, ...] / dom_dell_jax[mult_idx])
        return bkm_full

    z0 = param_coeff_M @ c_params + fixed_part_M
    zfull = param_coeff @ c_params + fixed_part
    bkm = param_coeff_bkm @ c_params + fixed_part_bkm
    bkm = foril(0, nmults, scale_bkm, bkm)
    clp = get_clp(bkm)

    def loop_in_mults(mult_ind, pred_acoeff):
        ell0 = ell0_arr_jax[mult_ind]
        omegaref = omega0_arr_jax[mult_ind]

        z0mult = z0[mult_ind]
        z1mult = zfull[mult_ind]/2./omegaref - z0mult

        _eigval1mult = get_eig_corr(clp[mult_ind], z1mult)*GVARS.OM*1e6

        Pjl_local = Pjl[mult_ind]
        pred_acoeff = jdc_update(pred_acoeff,
                                 (Pjl_local @ _eigval1mult)/Pjl_norm[mult_ind],
                                 (mult_ind * num_j,))
        return pred_acoeff

    pred_acoeffs = foril(0, nmults, loop_in_mults, pred_acoeffs)
    misfit_acoeffs = (pred_acoeffs - acoeffs_true)/acoeffs_sigma
    return numpyro.factor('obs', dist.Normal(0.0, 1.0).log_prob(misfit_acoeffs))

Instead of computing get_eig_corr, if I return jnp.ones() of the appropriate shape, the model computation takes about 1e-4 seconds, but with get_eig_corr it takes about 2e-2 seconds. I called it strange because it seems like a simple enough function that shouldn’t take long to compute. In fact, explicitly timing its execution takes about 5e-4 seconds, but something strange happens when it is called from within the model.
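One caveat when timing pieces of JAX code: execution is asynchronous, so without `block_until_ready()` a timer mostly measures dispatch rather than compute, and the first call of a jitted function also includes compilation time. That alone can make an isolated timing look much faster than the same function inside the model. A minimal sketch of careful timing (the function below is a placeholder, not the actual model):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # placeholder computation standing in for the model's log-density
    return (x @ x.T).sum()

x = jnp.ones((500, 500))
f(x).block_until_ready()            # first call: compilation happens here

t0 = time.perf_counter()
f(x).block_until_ready()            # block so we time compute, not async dispatch
elapsed = time.perf_counter() - t0
print(f"{elapsed:.6f} s")
```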