Why is AIES/NUTS time per iteration 10-100x slower than a single call to the model?

I have a complicated jitted model based on shard_map + inner vmap that takes ~0.03 seconds to evaluate; let’s call it eval_func(). I’ve optimized this function heavily so that I can run MCMC with numpyro in a reasonable amount of time. It solves ODEs using diffrax.

My numpyro model function is very simple – it draws parameters and calls eval_func(parameters) using a Gaussian likelihood with fixed/input sigma:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# eval_func is my jitted ODE solver, defined elsewhere

def model(obs_data, obs_sigma):

    # draw my 4 random parameters with log-uniform priors
    p1 = numpyro.sample('p1', dist.Uniform(-2, 2))
    p2 = numpyro.sample('p2', dist.Uniform(-2, 0))
    p3 = numpyro.sample('p3', dist.Uniform(-2, 0))
    p4 = numpyro.sample('p4', dist.Uniform(-2, 0))

    # evaluate my complicated jitted model given these input parameters
    params_rand = jnp.array([p1, p2, p3, p4])
    ode_solution = eval_func(params_rand)  # diffrax object

    # access and compute the model prediction from the returned ode_solution object
    model_prediction = jnp.log10(ode_solution.ys[0] / ode_solution.ys[1])

    # compute gaussian likelihood with fixed/input sigma (rather than inferring sigma)
    numpyro.sample('obs', dist.Normal(model_prediction, obs_sigma), obs=obs_data)

Again, eval_func only takes ~0.03 sec for reasonable input params_rand but:

  1. It takes ~1.3 sec to generate 5 prior predictive samples with Predictive(model, num_samples=5)(key(1), obs_data=None, obs_sigma=0.1)['obs']. This means 1.3/5 ~ 0.3 sec per prior predictive sample, which is 10x longer than the ~0.03 sec single call to eval_func(). If I instead draw just a single sample (num_samples=1), it drops to ~0.05 sec, so maybe this is just a numpyro parallelization issue, but as soon as num_samples=2 or higher it’s always ~1.3 sec.
  2. When using the gradient-free AIES sampler, the progress bar says ~3 “s/it” and it does seem to be taking ~2-3 seconds for the current sample # to increment by 1 iteration. This is 100x slower than a single call to eval_func (again only ~0.03 sec).
  3. When using the NUTS sampler, it’s usually ~500-1000x slower with ~20 s/it.

My questions:

  1. Can numpyro’s parameter drawing and likelihood evaluation really take 10-1000x longer than eval_func and dominate the time for a single iteration?
  2. Again my eval_func is already jitted – I assume numpyro is jitting under the hood when calling mcmc.run()? My numpyro model function suspiciously does not have a @jax.jit decorator. Should I be using jit_model_args=True when calling MCMC(AIES/NUTS(model),…)?
  3. How do I check whether numpyro is re-compiling itself internally during every iteration? (My eval_func() does NOT get re-compiled for different values of its input params_rand since the shape/dtype of params_rand is always the same.)
  4. I’m using 64-bit for jax with from jax import config; config.update("jax_enable_x64", True) but have NOT been using numpyro’s own enable_x64 – could this be causing the massive slowdown in numpyro?
  5. Assuming none of the above issues, does this imply that a “single iteration” of AIES/NUTS actually involves 100-1000 evaluations of eval_func() – how do I see how many times eval_func was called in a single numpyro iteration?
  6. Or is this just that my model is so complicated that numpyro’s likelihood evaluations really are 10-1000x more time-consuming than eval_func() itself? Is there an easy way for me to benchmark what part of numpyro itself is the slowest and might there be ways to speed those parts up?

I suspect the answer is a combination of the above. The fact that 2+ prior predictive samples take ~1.3 sec, ~50x longer than a single eval_func (~0.03 sec) or a single prior predictive sample (~0.05 sec), means there is some internal numpyro ~10x overhead just from drawing params_rand. And the fact that the # sec per iteration for AIES (NUTS) is 100x (1000x) slower than a single eval_func suggests that they are calling eval_func 100-1000x for a “single iteration”. Assuming I’ve already maximally optimized eval_func (can’t make a single call much faster than ~0.03 sec), does this imply I’ve reached a floor on how fast numpyro can go? Would the only other optimizations be tweaking the default AIES/NUTS sampler options…?

didn’t read your whole post but yes, NUTS does tens or hundreds or even thousands of evaluations & gradient computations per MCMC iteration


Thanks @martinjankowiak I just have a few quick questions left from my list at the bottom if/when you have time:

  1. Do I need to @jit my numpyro model or pass jit_model_args=True, or does numpyro automatically do all the jitting internally during the first call to mcmc.run()?
  2. Is there any easy way to check if numpyro is re-compiling itself during every iteration?
  3. Since I’m using jax’s from jax import config; config.update("jax_enable_x64", True) , do I also need numpyro’s enable_x64 or is it redundant/unnecessary?

  1. jit is done under the hood automatically
  2. I’m not sure, but it is very unlikely that this is the case
  3. should be redundant/unnecessary

MCMC can be expensive. such is life.


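On point 2, Python print statements inside a jitted function only execute while JAX traces the function (i.e. when it is compiled), not on later cached calls, so that’s a fairly easy way to debug: if the message shows up on every iteration, the function is being re-traced.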
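A minimal sketch of that debugging trick, using a hypothetical eval_func_stub in place of the real eval_func. The list trace_log is only there to count how often the Python body runs.

```python
import jax
import jax.numpy as jnp

trace_log = []  # grows only when JAX traces (re-compiles) the function


@jax.jit
def eval_func_stub(params):
    # This Python side effect runs during tracing, not on cached calls.
    trace_log.append(params.shape)
    print("tracing eval_func_stub for shape", params.shape, "dtype", params.dtype)
    return jnp.sum(params ** 2)


eval_func_stub(jnp.array([0.1, 0.2, 0.3, 0.4]))  # first call: traces and compiles
eval_func_stub(jnp.array([1.0, 2.0, 3.0, 4.0]))  # same shape/dtype: cached, no print
eval_func_stub(jnp.ones(5))                      # new shape: traces again

print("number of traces:", len(trace_log))
```

Putting the same kind of print inside the numpyro model function shows whether it is traced once at the start of mcmc.run() or repeatedly. Setting jax.config.update("jax_log_compiles", True) is another option that logs each compilation.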

With respect to AIES, I would expect querying multiple chains to be slower than evaluating a single chain if you’ve saturated your hardware accelerator. Each chain queries the likelihood once per iteration.


You can use utilities like log_density (in numpyro.infer.util) to measure how long the log density and its gradient take to evaluate. I guess grad could be expensive for your model.