There is no `while_loop` in the model. There is a `fori_loop` (imported below as `foril`).
Strangely, the slowest part of the model seems to be this function:

```python
NAX = jnp.newaxis

def get_eig_corr(clp, z1):
    return (clp.conj() * ((z1 * clp[:, NAX, :]).sum(axis=0))).sum(axis=0)
```
which is called within a `fori_loop` in the model:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.lax import fori_loop as foril
# jidx, jidx_update, jdc_update, get_clp and the arrays used below
# (cmin, cmax, true_params, param_coeff*, fixed_part*, Pjl, ...) are
# defined elsewhere in the script (not shown here).

def model():
    # Sample the scaling factors for the model parameters
    c_arr = numpyro.sample('c_arr', dist.Uniform(cmin, cmax))
    pred_acoeffs = jnp.zeros(num_j * nmults)
    c_params = c_arr * true_params

    def scale_bkm(mult_idx, bkm_full):
        # Scale the bkm block of multiplet `mult_idx` by -1/dom_dell_jax[mult_idx]
        bkm_full = jidx_update(bkm_full,
                               jidx[mult_idx, :, :, :],
                               -1.0*bkm_full[mult_idx, ...]/dom_dell_jax[mult_idx])
        return bkm_full

    z0 = param_coeff_M @ c_params + fixed_part_M
    zfull = param_coeff @ c_params + fixed_part
    bkm = param_coeff_bkm @ c_params + fixed_part_bkm
    bkm = foril(0, nmults, scale_bkm, bkm)
    clp = get_clp(bkm)

    def loop_in_mults(mult_ind, pred_acoeff):
        ell0 = ell0_arr_jax[mult_ind]
        omegaref = omega0_arr_jax[mult_ind]
        z0mult = z0[mult_ind]
        z1mult = zfull[mult_ind]/2./omegaref - z0mult
        _eigval1mult = get_eig_corr(clp[mult_ind], z1mult)*GVARS.OM*1e6
        Pjl_local = Pjl[mult_ind]
        # Write this multiplet's num_j coefficients into its slice of pred_acoeff
        pred_acoeff = jdc_update(pred_acoeff,
                                 (Pjl_local @ _eigval1mult)/Pjl_norm[mult_ind],
                                 (mult_ind * num_j,))
        return pred_acoeff

    pred_acoeffs = foril(0, nmults, loop_in_mults, pred_acoeffs)
    misfit_acoeffs = (pred_acoeffs - acoeffs_true)/acoeffs_sigma
    return numpyro.factor('obs', dist.Normal(0.0, 1.0).log_prob(misfit_acoeffs))
```
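For what it is worth, the broadcast in `get_eig_corr` can be checked by hand. Assuming `clp[mult_ind]` and `z1mult` are both 2-D arrays of shape `(d, d)` (the shapes are not shown in the post, so this is an assumption), `z1 * clp[:, NAX, :]` describes a `(d, d, d)` array whose sum over axis 0 is just `z1 * clp.sum(axis=0)`. XLA may or may not fuse that intermediate away, but the whole expression reduces to `(clp.conj() * z1).sum(axis=0) * clp.sum(axis=0)`, which only forms `(d, d)` arrays. The sketch below checks this numerically; `get_eig_corr_reduced` and the test size `d` are hypothetical.

```python
import numpy as np
import jax.numpy as jnp

NAX = jnp.newaxis

def get_eig_corr(clp, z1):
    # Expression from the post: broadcasts to (d, d, d) before reducing.
    return (clp.conj() * ((z1 * clp[:, NAX, :]).sum(axis=0))).sum(axis=0)

def get_eig_corr_reduced(clp, z1):
    # Hypothetical rewrite, valid for 2-D z1: the inner sum over axis 0
    # equals z1 * clp.sum(axis=0), so only (d, d) arrays are formed.
    return (clp.conj() * z1).sum(axis=0) * clp.sum(axis=0)

d = 64  # hypothetical matrix size
rng = np.random.default_rng(0)
clp = jnp.asarray(rng.standard_normal((d, d)) + 1j * rng.standard_normal((d, d)),
                  dtype=jnp.complex64)
z1 = jnp.asarray(rng.standard_normal((d, d)), dtype=jnp.float32)

full = get_eig_corr(clp, z1)
reduced = get_eig_corr_reduced(clp, z1)
print("max abs difference:", float(jnp.max(jnp.abs(full - reduced))))
```

If the two agree for the real shapes, swapping in the reduced form is one way to test whether the size of the intermediate is what costs time inside the model.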
If I replace the call to `get_eig_corr` with `jnp.ones()` of the appropriate shape, the model computation takes about 1e-4 seconds, but with `get_eig_corr` it takes about 2e-2 seconds. I call this strange because it seems like a simple enough function that shouldn’t take much time to compute. In fact, timing it explicitly gives about 5e-4 seconds, but something strange happens when it is called from within the model.
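When comparing a standalone timing with the in-model cost, two JAX details matter: the first call of a jitted function includes compilation, and work is dispatched asynchronously, so a timer must wait on `block_until_ready()`. The sketch below times one jitted call and a jitted `fori_loop` over all multiplets (mimicking `loop_in_mults`). The sizes `nmults` and `d`, the random inputs, and the helper names `looped` and `best_time` are hypothetical, not taken from the post.

```python
import time
import jax
import jax.numpy as jnp
from jax.lax import fori_loop as foril

NAX = jnp.newaxis

def get_eig_corr(clp, z1):
    return (clp.conj() * ((z1 * clp[:, NAX, :]).sum(axis=0))).sum(axis=0)

nmults, d = 20, 64  # hypothetical sizes
key_c, key_z = jax.random.split(jax.random.PRNGKey(0))
clp = jax.random.normal(key_c, (nmults, d, d))
z1 = jax.random.normal(key_z, (nmults, d, d))

single = jax.jit(get_eig_corr)

@jax.jit
def looped(clp, z1):
    # One get_eig_corr call per multiplet, as in loop_in_mults.
    def body(i, acc):
        return acc.at[i].set(get_eig_corr(clp[i], z1[i]))
    return foril(0, clp.shape[0], body, jnp.zeros(clp.shape[:2], clp.dtype))

def best_time(fn, *args, repeats=20):
    fn(*args).block_until_ready()  # warm-up call: triggers compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args).block_until_ready()  # wait until the device has finished
        times.append(time.perf_counter() - start)
    return min(times)

t_single = best_time(single, clp[0], z1[0])
t_loop = best_time(looped, clp, z1)
print(f"single call: {t_single:.2e} s, fori_loop over {nmults}: {t_loop:.2e} s")
```

Because `loop_in_mults` calls `get_eig_corr` once per multiplet, the single-call time should be multiplied by `nmults` before it is compared with the full model time.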