Benchmark gradient and log prob of Numpyro model

What is the best way to measure/benchmark the time it takes to evaluate the log probability of a Numpyro model, as well as its gradient with respect to the latent parameters? I am mostly interested in the context of MCMC, so I would like to benchmark the log probability and gradient functions that Numpyro uses when doing inference with NUTS.

As a specific example, how would I benchmark the eight_schools model from the Getting Started section?

import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

# Eight Schools example

def eight_schools(J, sigma, y=None):

    mu = numpyro.sample('mu', dist.Normal(0, 5))

    tau = numpyro.sample('tau', dist.HalfCauchy(5))

    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

nuts_kernel = NUTS(eight_schools)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

After looking through the source code for a bit I came up with the following solution:

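import time

import jax

rng_key, rng_key_ = random.split(rng_key)
model_info = numpyro.infer.util.initialize_model(
    rng_key_,
    eight_schools,
    dynamic_args=True,
    model_args=(J, sigma),
    model_kwargs={"y": y},
)
z = model_info.param_info.z  # initial point in unconstrained space
pe_fn = model_info.potential_fn(J, sigma, y=y)  # potential energy = -log density
value_and_grad_jit = jax.jit(jax.value_and_grad(pe_fn))
jax.block_until_ready(value_and_grad_jit(z))  # warm-up call so compilation is not timed
timings = []
for _ in range(100):
    start_time = time.perf_counter()
    # Block on the whole output (value and gradients); JAX dispatches asynchronously
    jax.block_until_ready(value_and_grad_jit(z))
    timings.append(time.perf_counter() - start_time)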
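The timing loop above can be wrapped in a small reusable helper that reports the first call (tracing plus XLA compilation) separately from the steady-state calls, and that can time the log probability alone as well as value-and-gradient. This is a sketch: `benchmark` is a hypothetical helper name, and `toy_potential` is a stand-in for `pe_fn` so the snippet runs on its own with only JAX installed. It assumes `fn` is already wrapped in `jax.jit` and uses `jax.block_until_ready`, which blocks on every array in the output pytree.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, n_repeats=100):
    """Time a jitted JAX function: first call vs. steady-state calls."""
    # First call includes tracing and compilation, so report it on its own.
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first_call = time.perf_counter() - start

    # Steady-state calls: wait for the full output pytree before stopping the clock.
    timings = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        timings.append(time.perf_counter() - start)

    return {
        "first_call_s": first_call,
        "median_s": statistics.median(timings),
        "min_s": min(timings),
        "timings": timings,
    }


if __name__ == "__main__":
    # Hypothetical stand-in for pe_fn over a dict of unconstrained parameters;
    # with a real model, pass pe_fn and z = model_info.param_info.z instead.
    def toy_potential(z):
        return 0.5 * jnp.sum(z["theta"] ** 2) + 0.5 * z["mu"] ** 2

    z = {"mu": jnp.array(0.3), "theta": jnp.ones(8)}
    log_prob_jit = jax.jit(toy_potential)
    value_and_grad_jit = jax.jit(jax.value_and_grad(toy_potential))

    for name, f in [("potential", log_prob_jit), ("value_and_grad", value_and_grad_jit)]:
        res = benchmark(f, z)
        print(f"{name}: first call {res['first_call_s'] * 1e3:.2f} ms, "
              f"median {res['median_s'] * 1e6:.1f} us")
```

Reporting the median (or minimum) rather than the mean makes the numbers less sensitive to occasional slow iterations from the OS or Python's garbage collector.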

Does that look reasonable?


Yup, nicely done!
