I have a complicated jitted model based on shard_map + inner vmap that takes ~0.03 seconds to evaluate; call it eval_func(). This is the function I've worked hard to optimize so that I can do MCMC with numpyro in a reasonable amount of time. It solves ODEs using diffrax.
My numpyro model function is very simple – it draws parameters, calls eval_func(parameters), and evaluates a Gaussian likelihood with fixed/input sigma:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(obs_data, obs_sigma):
    # draw my 4 random parameters with log-uniform priors
    p1 = numpyro.sample('p1', dist.Uniform(-2, 2))
    p2 = numpyro.sample('p2', dist.Uniform(-2, 0))
    p3 = numpyro.sample('p3', dist.Uniform(-2, 0))
    p4 = numpyro.sample('p4', dist.Uniform(-2, 0))
    # evaluate my complicated jitted model given these input parameters
    params_rand = jnp.array([p1, p2, p3, p4])
    ode_solution = eval_func(params_rand)  # diffrax Solution object
    # compute the model prediction from the returned object
    model_prediction = jnp.log10(ode_solution.ys[0] / ode_solution.ys[1])
    # Gaussian likelihood with fixed/input sigma (rather than inferring sigma)
    numpyro.sample('obs', dist.Normal(model_prediction, obs_sigma), obs=obs_data)
Again, eval_func only takes ~0.01 sec for reasonable input params_rand, but:
- It takes ~1.3 sec to generate 5 prior predictive samples with
Predictive(model, num_samples=5)(key(1), obs_data=None, obs_sigma=0.1)['obs']
That works out to 1.3/5 ~ 0.3 sec per prior predictive sample, roughly 10x longer than the ~0.03 sec single call to eval_func(). If I instead use num_samples=1, it drops to ~0.05 sec, so maybe this is just a numpyro parallelization issue, but as soon as num_samples is 2 or higher it's always ~1.3 sec. (A rough timing sketch follows this list.)
- When using the gradient-free AIES sampler, the progress bar reports ~3 s/it, and it does indeed take ~2-3 seconds for the sample counter to advance by one iteration. This is ~100x slower than a single call to eval_func (again only ~0.03 sec).
- When using the NUTS sampler, it’s usually ~500-1000x slower with ~20 s/it.
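For reference, this is roughly how I'm measuring these timings (a sketch only: the timed helper and the parameter values are illustrative, the first call is thrown away so compilation isn't counted, and block_until_ready makes sure JAX's asynchronous dispatch has actually finished):

import time
import jax
import jax.numpy as jnp
from jax import random
from numpyro.infer import Predictive

def timed(fn, *args, n_repeats=10):
    # discard the first call, which may include compilation
    jax.block_until_ready(fn(*args))
    t0 = time.perf_counter()
    for _ in range(n_repeats):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - t0) / n_repeats

# a single eval_func call (parameter values are arbitrary but within the priors)
print("eval_func:", timed(lambda p: eval_func(p).ys, jnp.array([0.0, -1.0, -1.0, -1.0])))

# prior predictive for a fixed num_samples
prior_pred = Predictive(model, num_samples=5)
print("Predictive:", timed(lambda k: prior_pred(k, obs_data=None, obs_sigma=0.1)["obs"],
                           random.PRNGKey(1)))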
My questions:
- Can numpyro’s parameter drawing and likelihood evaluation really take 10-1000x longer than eval_func and dominate the time for a single iteration?
- Again, my eval_func is already jitted – I assume numpyro is jitting under the hood when calling mcmc.run()? My numpyro model function suspiciously does not have a @jax.jit decorator. Should I be passing jit_model_args=True when calling MCMC(AIES/NUTS(model), ...)?
- How do I check whether numpyro is re-compiling itself internally during every iteration? (My eval_func() does NOT get re-compiled for different values of its input params_rand, since the shape/dtype of params_rand is always the same.) See the diagnostics sketch after this list.
- I'm using 64-bit for jax with
from jax import config; config.update("jax_enable_x64", True)
but have NOT been calling numpyro's own enable_x64() – could this be causing the massive slowdown in numpyro?
- Assuming none of the above is the issue, does this imply that a "single iteration" of AIES/NUTS actually involves 100-1000 evaluations of eval_func()? How do I see how many times eval_func was called in a single numpyro iteration? (A crude call-counting sketch is at the end of this post.)
- Or is it just that my model is so complicated that numpyro's likelihood evaluations really are 10-1000x more time-consuming than eval_func() itself? Is there an easy way to benchmark which parts of numpyro are the slowest, and might there be ways to speed those parts up?
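For concreteness, the setup and diagnostics I have in mind look roughly like the sketch below (the warmup/sample counts are placeholders, and obs_data/obs_sigma are the arrays passed to model above):

import jax
import numpyro
from numpyro.infer import MCMC, NUTS  # numpyro.infer.AIES is the ensemble alternative
from jax import random

# log every XLA compilation; messages appearing while sampling is already underway
# (beyond the first iteration / warmup) would indicate re-compilation
jax.config.update("jax_log_compiles", True)

# keep numpyro itself in 64-bit, matching jax_enable_x64
numpyro.enable_x64()

kernel = NUTS(model)  # AIES(model) would instead need multiple chains
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, jit_model_args=True)
mcmc.run(random.PRNGKey(0), obs_data, obs_sigma)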
I suspect the answer is a combination of the above. The fact that generating 2+ prior predictive samples takes ~1.3 sec, roughly 50x longer than a single eval_func call (~0.03 sec) or a single prior predictive sample (~0.05 sec), suggests there is some ~10x internal numpyro overhead just from drawing params_rand. And the fact that the seconds per iteration for AIES (NUTS) are ~100x (~1000x) a single eval_func call suggests they are calling eval_func 100-1000 times per "single iteration". Assuming I've already maximally optimized eval_func (a single call can't get much faster than ~0.01 sec), does this mean I've hit a floor on how fast numpyro can go, and the only remaining optimizations are tweaks to the default AIES/NUTS sampler options?
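For the call-counting question above, this is the kind of crude instrumentation I was picturing (a sketch: jax.debug.callback executes a host-side Python function each time the compiled computation reaches that point, though under vmapped/vectorized chains it may fire once per batched call rather than once per underlying evaluation):

import jax
import jax.numpy as jnp

# host-side counter, bumped every time the traced computation executes this point
eval_calls = {"n": 0}

def _bump(_):
    eval_calls["n"] += 1

def counted_eval(params):
    jax.debug.callback(_bump, jnp.zeros(()))
    return eval_func(params)

# swap counted_eval in for eval_func inside model(), run a handful of MCMC
# iterations, then inspect eval_calls["n"] / number_of_iterations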