How to time just the MCMC sampling, not including compile time?

Hello,

I get some samples from a posterior using the following code.

kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup = num_warmup,
    num_samples = num_samples,
    num_chains = num_chains
)
mcmc.run(rng_key, Z, M, sd, sd_prior, wts)
mcmc_results = mcmc.get_samples()

Right now I’m just wrapping this code in timing statements, but that includes the time taken to compile the model. The compile time seems to increase more noticeably when I sample in parallel. I want to time just the sampling procedure, not including the compile time.

Is this possible?

Thanks!

I guess you are using numpyro? You can use the ._compile method to isolate the compilation time, as in this benchmark code.
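Roughly, the pattern looks like this (a sketch using your variables from above; note that _compile is a private method, so its exact signature may change between numpyro releases):

import time

t0 = time.time()
# Trace and compile the sampler once, without drawing samples.
mcmc._compile(rng_key, Z, M, sd, sd_prior, wts, extra_fields=("num_steps",))
print("compile time:", time.time() - t0)

t0 = time.time()
# The compiled sampler is now cached, so this call should only sample.
mcmc.run(rng_key, Z, M, sd, sd_prior, wts, extra_fields=("num_steps",))
print("sampling time:", time.time() - t0)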


Thank you!

Hi, thanks for your help!

I’ve incorporated the benchmark code into my model. However, I’m seeing the model compile three times: once when ._compile runs, then again before the warmup steps, and again before the sampling steps. I guess I’m confused about what is actually being compiled by ._compile, when the progress bar seems to show compilation twice more?

Or is the forced compilation just so we can get an estimate of the compile time, to then subtract from the sampling procedure?

I’m not sure why compilation is triggered three times in your case. Probably there is a regression in numpyro, or your model inputs are not static… Could you try running this test to see if there is a delay between those phases?

What parameters would you like me to run this with? Running it with 8 chains in parallel (which is what I’m doing for my model) immediately fails the test. Running with 8 chains sequentially, I get:

sample: 100%|β–ˆ| 20/20 [00:00<00:00, 21.07it/s, 3 steps of size 1.22e+00. acc. pr
sample: 100%|β–ˆ| 20/20 [00:00<00:00, 2108.27it/s, 1 steps of size 9.51e-01. acc. 
sample: 100%|β–ˆ| 20/20 [00:00<00:00, 1872.75it/s, 1 steps of size 7.83e-01. acc. 
sample: 100%|β–ˆ| 20/20 [00:00<00:00, 1970.17it/s, 1 steps of size 7.35e-01. acc. 
sample: 100%|β–ˆ| 20/20 [00:00<00:00, 1900.97it/s, 1 steps of size 1.02e+00. acc. 
sample: 100%|β–ˆ| 20/20 [00:00<00:00, 1866.91it/s, 7 steps of size 8.82e-01. acc. 
sample: 100%|β–ˆ| 20/20 [00:00<00:00, 1889.20it/s, 3 steps of size 6.76e-01. acc. 
sample: 100%|β–ˆ| 20/20 [00:00<00:00, 2200.46it/s, 3 steps of size 8.39e-01. acc. 
warmup: 100%|β–ˆ| 10/10 [00:00<00:00, 2073.51it/s, 7 steps of size 1.22e+00. acc. 
warmup: 100%|β–ˆ| 10/10 [00:00<00:00, 2028.59it/s, 1 steps of size 9.51e-01. acc. 
warmup: 100%|β–ˆ| 10/10 [00:00<00:00, 1961.33it/s, 15 steps of size 7.83e-01. acc.
warmup: 100%|β–ˆ| 10/10 [00:00<00:00, 1619.55it/s, 15 steps of size 7.35e-01. acc.
warmup: 100%|β–ˆ| 10/10 [00:00<00:00, 2067.48it/s, 3 steps of size 1.02e+00. acc. 
warmup: 100%|β–ˆ| 10/10 [00:00<00:00, 1734.47it/s, 9 steps of size 8.82e-01. acc. 
warmup: 100%|β–ˆ| 10/10 [00:00<00:00, 1946.76it/s, 7 steps of size 6.76e-01. acc. 
warmup: 100%|β–ˆ| 10/10 [00:00<00:00, 1571.49it/s, 1 steps of size 8.39e-01. acc. 
sample: 100%|β–ˆ| 10/10 [00:00<00:00, 2087.65it/s, 3 steps of size 1.22e+00. acc. 
sample: 100%|β–ˆ| 10/10 [00:00<00:00, 1859.18it/s, 1 steps of size 9.51e-01. acc. 
sample: 100%|β–ˆ| 10/10 [00:00<00:00, 1908.15it/s, 1 steps of size 7.83e-01. acc. 
sample: 100%|β–ˆ| 10/10 [00:00<00:00, 1735.69it/s, 1 steps of size 7.35e-01. acc. 
sample: 100%|β–ˆ| 10/10 [00:00<00:00, 1572.02it/s, 1 steps of size 1.02e+00. acc. 
sample: 100%|β–ˆ| 10/10 [00:00<00:00, 2007.23it/s, 7 steps of size 8.82e-01. acc. 
sample: 100%|β–ˆ| 10/10 [00:00<00:00, 1611.71it/s, 3 steps of size 6.76e-01. acc. 
sample: 100%|β–ˆ| 10/10 [00:00<00:00, 2008.77it/s, 3 steps of size 8.39e-01. acc. 
sample: 100%|β–ˆ| 20/20 [00:00<00:00, 22.55it/s, 7 steps of size 6.33e-01. acc. pr
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Input In [28], in <cell line: 1>()
----> 1 test_compile_warmup_run(8, 'sequential', True)

Input In [23], in test_compile_warmup_run(num_chains, chain_method, progress_bar)
     42 mcmc.run(rng_key)
     43 first_chain_samples = mcmc.get_samples()["x"]
---> 44 assert_allclose(actual_samples[:num_samples], first_chain_samples, atol=1e-5)

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/bcore/lib/python3.9/site-packages/numpy/testing/_private/utils.py:844, in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)
    840         err_msg += '\n' + '\n'.join(remarks)
    841         msg = build_err_msg([ox, oy], err_msg,
    842                             verbose=verbose, header=header,
    843                             names=('x', 'y'), precision=precision)
--> 844         raise AssertionError(msg)
    845 except ValueError:
    846     import traceback

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-05

Mismatched elements: 10 / 10 (100%)
Max absolute difference: 2.4036353
Max relative difference: 37.921455
 x: array([-1.116398, -1.116398,  0.089071,  0.841793,  0.886094,  0.648453,
        1.747691, -0.288125, -0.721614,  0.112518], dtype=float32)
 y: array([ 0.030237, -0.993181, -1.382091, -0.197607, -0.769425, -0.806987,
        0.407783,  0.505296, -0.859371, -2.291117], dtype=float32)

I can’t tell whether any compiling is going on here as it is so quick. It seems like there is no delay, though.

This is my Bayesian neural network model:

def numpyro_model(Z, M, sd, sd_prior, wts):
    """Numpyro model to facilitate MCMC NUTS sampling. Mirrors the pytorch implementation.
    
    Parameters
    ----------
    Z: numpy array
        data array formatted as [X, y]
        
    M: int
        number of hidden layer nodes
        
    sd: float
        standard deviation of errors
    
    mu_prior: D numpy vector 
        prior mean (NOT IMPLEMENTED)
        
    sd_prior: float
        prior standard deviation
        
    wts: numpy vector
        coreset weights
    """
    # Extract X and y
    X = Z[:, :-1]
    y = Z[:, -1].reshape(-1, 1)
    
    # Get the number of data points and the input and output dimensions
    N, D = X.shape
    P = y.shape[1]

    # Sample input layer weights (Normal prior)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D, M)), sd_prior * jnp.ones((D, M))))
    assert w1.shape == (D, M)
    
    # Sample input layer biases (Normal prior)
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((M)), sd_prior * jnp.ones((M))))
    assert b1.shape == (M, )
    
    # Compute first layer activations
    z1 = jnp.tanh((jnp.matmul(X, w1) + b1))
    assert z1.shape == (N, M)

    # Sample output layer weights (Normal prior)
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((M, P)), sd_prior * jnp.ones((M, P))))
    assert w2.shape == (M, P)
    
    # Sample output layer bias (Normal prior)
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((P)), sd_prior * jnp.ones((P))))
    assert b2.shape == (P, )
    
    # Compute output
    z2 = jnp.matmul(z1, w2) + b2
    assert z2.shape == (N, P)

    if y is not None:
        assert z2.shape == y.shape

    # Observe the data
    with numpyro.plate("data", N), numpyro.handlers.scale(scale = wts):
        numpyro.sample("y", dist.Normal(z2, sd).to_event(1), obs = y)

This is my sampling wrapper:

def sample_posterior(model, num_samples, num_warmup, num_chains, rng_key, Z, M, sd, sd_prior, wts, verbose = False):
    """Sample the posterior of the neural net numpyro model.
    
    Parameters
    ----------
    model: numpyro function
        numpyro neural net model
    
    num_samples: int > 0
        number of samples per chain
    
    num_warmup: int > 0
        number of warmup samples per chain
    
    num_chains: int > 0
        number of MCMC chains to run
    
    rng_key: key from jax.random.PRNGKey
        key controlling the random state
    
    Z: numpy array
        data array formatted as [X, y]
        
    M: int
        number of hidden layer nodes
        
    sd: float
        standard deviation of errors
    
    mu_prior: D numpy vector 
        prior mean (NOT IMPLEMENTED)
        
    sd_prior: float
        prior standard deviation
    
    wts: numpy vector
        coreset weights
    
    verbose: boolean
        flag to indicate whether timings and summaries should be printed

    Returns
    -------
    mcmc_results: dictionary
        dictionary of posterior samples
    """
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup = num_warmup,
        num_samples = num_samples,
        num_chains = num_chains
    )

    # Get compile time
    compile_time = time.time()
    mcmc._compile(rng_key, Z, M, sd, sd_prior, wts, extra_fields = ('num_steps',))
    compile_time = time.time() - compile_time
    if verbose:
        print('Compiling time:', compile_time, '\n')

    # Get warmup time
    warmup_time = time.time()
    mcmc.warmup(rng_key, Z, M, sd, sd_prior, wts, extra_fields = ('num_steps',))
    mcmc.num_samples = num_samples
    rng_key = mcmc._warmup_state.rng_key.copy()
    warmup_time = time.time() - warmup_time
    if verbose:
        print('Warmup time:', warmup_time, '\n')

    # Get sampling time
    sample_time = time.time()
    mcmc.run(rng_key, Z, M, sd, sd_prior, wts, extra_fields=('num_steps',))
    mcmc._last_state.rng_key.copy()
    sample_time = time.time() - sample_time
    if verbose:
        print('Sampling time:', sample_time, '\n')
        mcmc.print_summary()
    
    if verbose:
        num_leapfrogs = np.sum(mcmc.get_extra_fields()['num_steps'])
        print('Number of leapfrog steps:', num_leapfrogs)
        time_per_leapfrog = sample_time / num_leapfrogs
        print('Time per leapfrog step:', time_per_leapfrog)

        num_effs = [
            numpyro.diagnostics.effective_sample_size(device_get(v))
            for k, v in mcmc.get_samples(group_by_chain = True).items()
            ]

        num_effs = np.concatenate([np.reshape(x, -1) for x in num_effs])
        num_eff_mean = sum(num_effs) / len(num_effs)
        print('Average number of effective samples:', num_eff_mean)
        time_per_eff_sample = sample_time / num_eff_mean
        print('Time per effective sample:', time_per_eff_sample)

    mcmc_results = mcmc.get_samples()
    mcmc_time = sample_time - compile_time
    return mcmc_results, mcmc_time
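For reference, I call this roughly as follows (the sample counts here are placeholders):

rng_key = jax.random.PRNGKey(0)
mcmc_results, mcmc_time = sample_posterior(
    numpyro_model,
    num_samples = 1000,
    num_warmup = 500,
    num_chains = 8,
    rng_key = rng_key,
    Z = Z, M = M, sd = sd, sd_prior = sd_prior, wts = wts,
    verbose = True,
)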

How about defining your model as follows

def model():
    ...

rather than

def model(foo, bar):
    ...

So that you don’t need to provide inputs to compile/warmup/run. If compiling is still triggered, then it is likely a numpyro issue. If compiling is not triggered, then some of your inputs are not static - you can add them back one by one to the model signature to see which one it is.
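If you prefer not to rely on globals, an equivalent option (my suggestion; the effect is the same as closing over the inputs) is to bind them with functools.partial:

from functools import partial

# Bind the fixed inputs once; the bound model takes no arguments,
# so compile/warmup/run only need the rng key.
bound_model = partial(numpyro_model, Z, M, sd, sd_prior, wts)
kernel = NUTS(bound_model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
mcmc.run(rng_key)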

What do you mean by static inputs? As in, the model itself altering the input data?

I’m not sure how any of the inputs would be changing as they are all fixed values/arrays, except maybe the rng key? I initialise that by just doing

from random import randrange

rng_key = jax.random.PRNGKey(randrange(1000))

I’ve removed all the inputs and rerun, but I still see compilation happening three times:

N = 500
X = np.linspace(-10, 10, N).reshape(-1, 1)
generating_mean = X + 2.5*np.sin(X) + 5*np.exp(-(X - 1)**2) - 5*np.exp(-(X + 3)**2)

sd = 2
y = np.random.default_rng().normal(loc = generating_mean, scale = sd).reshape(-1, 1)
Z = np.hstack((X, y))

M = 25
sd_prior = 1
wts = np.ones(Z.shape[0])
def numpyro_model():
    # Get the number of data points and the input and output dimensions
    N, D = X.shape
    P = y.shape[1]

    # Sample input layer weights (Normal prior)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D, M)), sd_prior * jnp.ones((D, M))))
    assert w1.shape == (D, M)
    
    # Sample input layer biases (Normal prior)
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((M)), sd_prior * jnp.ones((M))))
    assert b1.shape == (M, )
    
    # Compute first layer activations
    z1 = jnp.tanh((jnp.matmul(X, w1) + b1))
    assert z1.shape == (N, M)

    # Sample output layer weights (Normal prior)
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((M, P)), sd_prior * jnp.ones((M, P))))
    assert w2.shape == (M, P)
    
    # Sample output layer bias (Normal prior)
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((P)), sd_prior * jnp.ones((P))))
    assert b2.shape == (P, )
    
    # Compute output
    z2 = jnp.matmul(z1, w2) + b2
    assert z2.shape == (N, P)

    if y is not None:
        assert z2.shape == y.shape

    # Observe the data
    with numpyro.plate("data", N), numpyro.handlers.scale(scale = wts):
        numpyro.sample("y", dist.Normal(z2, sd).to_event(1), obs = y)

def sample_posterior(rng_key, verbose = False):
    kernel = NUTS(numpyro_model)
    mcmc = MCMC(
        kernel,
        num_warmup = num_warmup,
        num_samples = num_samples,
        num_chains = num_chains
    )

    ## Get compile time
    compile_time = time.time()
    mcmc._compile(rng_key, extra_fields = ('num_steps',))
    compile_time = time.time() - compile_time
    if verbose:
        print('Compiling time:', compile_time, '\n')

    # Get warmup time
    warmup_time = time.time()
    mcmc.warmup(rng_key, extra_fields = ('num_steps',))
    mcmc.num_samples = num_samples
    rng_key = mcmc._warmup_state.rng_key.copy()
    warmup_time = time.time() - warmup_time
    if verbose:
        print('Warmup time:', warmup_time, '\n')

    # Get sampling time
    sample_time = time.time()
    mcmc.run(rng_key, extra_fields=('num_steps',))
    mcmc._last_state.rng_key.copy()
    sample_time = time.time() - sample_time
    if verbose:
        print('Sampling time:', sample_time, '\n')
        mcmc.print_summary()

    mcmc_results = mcmc.get_samples()
    #mcmc_time = sample_time - compile_time
    return mcmc_results, sample_time

full_mcmc_results, full_sample_time = sample_posterior(rng_key, verbose = False)

I just quickly tested your code in Colab with the latest numpyro version. I got something like

Compiling time: 4.722844362258911 

warmup: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 100/100 [00:00<00:00, 123.97it/s, 5 steps of size 2.29e-02. acc. prob=0.75]  
Warmup time: 0.8248724937438965 

sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 100/100 [00:00<00:00, 304.83it/s, 255 steps of size 2.29e-02. acc. prob=0.96]Sampling time: 0.34980177879333496 

which indicates that compilation is not triggered in the warmup and sampling phases.

Re static: it depends on the context, but here I meant inputs that cause jax to recompile a jitted function. Sometimes integers like 0 or 1 will have weak type in the first iteration and concrete type in later iterations.
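Here is a small repro of that weak-type effect (a sketch, assuming a recent jax; the print only fires while jax traces the function, i.e. on each fresh compilation):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Executes only when jax (re)traces f, so each print marks a compilation.
    print("tracing f")
    return x + 1

f(1.0)               # compiles: weakly-typed float32 scalar
f(2.0)               # cache hit: same abstract type, no retracing
f(jnp.float32(2.0))  # retraces: concrete float32 differs from the weak type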

Weird, it still compiles three times on my end. Could you share the exact code you ran?

If I run the chains sequentially, the compilation issue goes away, even with my original code. It is only when I run the chains in parallel that I see compilation happening three times. I am enabling parallel execution just by using

numpyro.set_host_device_count(8)

I see. Currently, we can’t separate out compiling time for parallel map (see this issue).
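Until that is resolved, one workaround (my suggestion, not something from the issue) is to run the sampler twice with identical inputs and time only the second run - the compiled parallel map is cached on the MCMC object, so the second run should be close to pure sampling:

import time

mcmc.run(rng_key)   # first run pays the compilation cost
t0 = time.time()
mcmc.run(rng_key)   # cached: should be (close to) pure sampling
print("sampling time:", time.time() - t0)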

Ah, that makes sense then. I’ll keep an eye on the issue.

Thanks for all your help!