There is no `while_loop` in the model. There is a `fori_loop` (imported below as `foril`).
Strangely, the slowest part of the model seems to be this function

```
import jax.numpy as jnp

NAX = jnp.newaxis

def get_eig_corr(clp, z1):
    return (clp.conj() * ((z1 * clp[:, NAX, :]).sum(axis=0))).sum(axis=0)
```
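For reference, the same contraction can be written as a single `jnp.einsum`, which sometimes fuses better under XLA than the broadcast-multiply-then-sum form. This is a sketch, assuming `clp` has shape `(n, m)` and `z1` has shape `(n, n, m)` (shapes inferred from the broadcasting in the original function, not stated in the post):

```python
import jax.numpy as jnp
import numpy as np

NAX = jnp.newaxis

def get_eig_corr(clp, z1):
    # original formulation: broadcast multiply, then two reductions
    return (clp.conj() * ((z1 * clp[:, NAX, :]).sum(axis=0))).sum(axis=0)

def get_eig_corr_einsum(clp, z1):
    # equivalent contraction: out[l] = sum_{i,k} conj(clp[i,l]) * z1[k,i,l] * clp[k,l]
    return jnp.einsum('il,kil,kl->l', clp.conj(), z1, clp)
```

Both versions compute the same result; whether the einsum is actually faster would need to be checked on the real array sizes.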

which is called within a `fori_loop` in the model:

```
def model():
    c_arr = numpyro.sample('c_arr', dist.Uniform(cmin, cmax))
    pred_acoeffs = jnp.zeros(num_j * nmults)
    c_params = c_arr * true_params

    def scale_bkm(mult_idx, bkm_full):
        # index with the loop variable mult_idx, not a stale i
        bkm_full = jidx_update(bkm_full,
                               jidx[mult_idx, :, :, :],
                               -1.0 * bkm_full[mult_idx, ...] / dom_dell_jax[mult_idx])
        return bkm_full

    z0 = param_coeff_M @ c_params + fixed_part_M
    zfull = param_coeff @ c_params + fixed_part
    bkm = param_coeff_bkm @ c_params + fixed_part_bkm
    bkm = foril(0, nmults, scale_bkm, bkm)
    clp = get_clp(bkm)

    def loop_in_mults(mult_ind, pred_acoeff):
        ell0 = ell0_arr_jax[mult_ind]
        omegaref = omega0_arr_jax[mult_ind]
        z0mult = z0[mult_ind]
        z1mult = zfull[mult_ind] / 2. / omegaref - z0mult
        _eigval1mult = get_eig_corr(clp[mult_ind], z1mult) * GVARS.OM * 1e6
        Pjl_local = Pjl[mult_ind]
        pred_acoeff = jdc_update(pred_acoeff,
                                 (Pjl_local @ _eigval1mult) / Pjl_norm[mult_ind],
                                 (mult_ind * num_j,))
        return pred_acoeff

    pred_acoeffs = foril(0, nmults, loop_in_mults, pred_acoeffs)
    misfit_acoeffs = (pred_acoeffs - acoeffs_true) / acoeffs_sigma
    return numpyro.factor('obs', dist.Normal(0.0, 1.0).log_prob(misfit_acoeffs))
```
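The `scale_bkm` pattern (a functional per-multiplet update inside `fori_loop`) can be reproduced standalone with the newer `.at[...]` indexed-update API, which replaces the older `jax.ops.index_update`/`index` helpers the model aliases as `jidx_update`/`jidx`. This is a minimal sketch with made-up shapes and a hypothetical `dom_dell` array, not the model's real data:

```python
import jax
import jax.numpy as jnp

# hypothetical sizes for illustration only
nmults = 4
bkm = jnp.ones((nmults, 3, 3, 2))
dom_dell = jnp.arange(1.0, nmults + 1)  # stand-in for dom_dell_jax

def scale_bkm(mult_idx, bkm_full):
    # functional update of one multiplet slice; .at[...].set returns a new array
    return bkm_full.at[mult_idx].set(-1.0 * bkm_full[mult_idx] / dom_dell[mult_idx])

bkm_scaled = jax.lax.fori_loop(0, nmults, scale_bkm, bkm)
```

Inside `jit`-traced code such as a NumPyro model, XLA typically turns these `.at` updates into in-place buffer writes, so the loop itself is rarely the bottleneck.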

Instead of computing `get_eig_corr`, if I return `jnp.ones()` of the appropriate shape, the model computation takes about 1e-4 seconds, but with `get_eig_corr` it takes about 2e-2 seconds. I called it strange because it seems like a simple enough function that shouldn't take a lot of time to compute. In fact, explicitly timing its execution takes about 5e-4 seconds, but something strange happens when it is called from within the model.