Convergence checks on MCMC while sampling without recompiling?

Is it possible to check convergence during an MCMC NUTS sampling phase and bail out if some criterion is met?

At the minute I am doing something like:

        kernel = NUTS(self.numpyro_model)
        mcmc = MCMC(
            kernel,
            num_warmup=1000,
            num_samples=1,
            num_chains=self.num_chains,
            chain_method="parallel",
        )

        mcmc.run(
            rng,
            t=self.t,
            y=self.y,
        )

        converged = False
        iteration = 0
        state = mcmc.last_state
        mcmc = MCMC(
            kernel,
            num_warmup=0,
            num_samples=self.converge_step,
            num_chains=self.num_chains,
            chain_method="parallel",
            progress_bar=True,
        )
        mcmc.post_warmup_state = state
        for iteration in range(int(self.max_mc_steps / self.converge_step)):
            mcmc.run(mcmc.post_warmup_state.rng_key, t=self.t)
            mcmc.post_warmup_state = mcmc.last_state
            az_summary = az.summary(az.from_numpyro(mcmc))
            # do some check on convergence
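
For concreteness, the kind of check I have in mind would be something like this (hypothetical sketch: it uses split_gelman_rubin from numpyro.diagnostics on the draws from the most recent run only, with an arbitrary 1.01 threshold, and assumes jax.numpy is imported as jnp):

        from numpyro.diagnostics import split_gelman_rubin

        for iteration in range(int(self.max_mc_steps / self.converge_step)):
            mcmc.run(mcmc.post_warmup_state.rng_key, t=self.t)
            mcmc.post_warmup_state = mcmc.last_state
            # samples from the latest run, shaped (num_chains, num_samples, ...) per site
            samples = mcmc.get_samples(group_by_chain=True)
            # split r_hat per site; stop once every value is below the threshold
            r_hats = {name: split_gelman_rubin(value) for name, value in samples.items()}
            if all(jnp.all(r < 1.01) for r in r_hats.values()):
                converged = True
                break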

But it seems to recompile for every step in the loop.

How can I prevent this recompiling?

I think the recompiling happens in arviz. Could you remove the line az_summary = az.summary(az.from_numpyro(mcmc)) to double-check? I think you can use az.from_dict or something similar to work with the posterior samples directly, instead of the mcmc instance.
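
Something like this should work (rough sketch; get_samples(group_by_chain=True) returns arrays shaped (chain, draw, ...), which is what az.from_dict expects):

        posterior = mcmc.get_samples(group_by_chain=True)
        idata = az.from_dict(posterior=posterior)
        az_summary = az.summary(idata)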

Ah!

Great - I’ll try that and report back.

Thanks @fehiepsi!

No - this doesn’t help. I thought it did, but the recompiling is still happening with just:

        for iteration in range(int(self.max_mc_steps / self.converge_step)):
            mcmc.run(mcmc.post_warmup_state.rng_key, t=self.t)
            mcmc.post_warmup_state = mcmc.last_state

So I guess each call of mcmc.run is causing the recompile. Is there any way around this, or a suggested approach to checking for convergence in MCMC on the fly?

Maybe set jit_model_args=True?

Yeah - I tried that, switched the progress bar off and added a print statement to the model, but it still seems to recompile. :smiling_face_with_tear:

I just ran the test numpyro/test/infer/test_mcmc.py at 8e1d9b258a3e4be3076ee2ba81d832e45fba792e · pyro-ppl/numpyro · GitHub and it worked, so there is probably something off in your setup. Because you didn’t include a minimal reproducible example, it’s hard to tell.
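
For reference, the pattern that test exercises is basically the documented post_warmup_state usage - the same MCMC instance is run repeatedly and the later runs should not trigger a recompile. A paraphrased toy version (not the actual test code, just a Normal model with no data arguments):

    import jax
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model():
        numpyro.sample("x", dist.Normal(0.0, 1.0))

    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100, progress_bar=False)
    mcmc.run(jax.random.PRNGKey(0))

    # continue sampling from where the previous run left off;
    # this second run should reuse the already-compiled kernel
    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(mcmc.post_warmup_state.rng_key)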

Hi @fehiepsi - my bad.

Here is an example that reproduces the problem:

import arviz as az
import celerite2.jax
import jax
import jax.numpy as jnp
import jaxopt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

numpyro.set_host_device_count(10)


class celerite2_gp:

    def __init__(
        self,
        kernel_fn,
        t,
        yerr,
        y,
        max_mc_steps=10000,
        num_chains=10,
        num_warmup=1000,
        converge_step=2500,
    ):

        self.kernel_fn = kernel_fn
        self.t = t
        self.y = y
        self.y_err = yerr
        self.mean = None
        self.max_mc_steps = max_mc_steps
        self.num_chains = num_chains
        self.num_warmup = num_warmup
        self.converge_step = converge_step

    def setup_gp(self, params, t, y=None):
        kernel = self.kernel_fn(params=params)
        self.gp = celerite2.jax.GaussianProcess(kernel)
        self.gp.compute(t, check_sorted=False)

    def negative_log_likelihood(self, params, t, y):
        self.setup_gp(params, t, y)
        return -self.gp.log_likelihood(y)

    def numpyro_model(self, t, y=None, params=None):
        self.setup_gp(params, t, y)
        numpyro.sample("obs", self.gp.numpyro_dist(), obs=self.y)

    def minimize(self):
        initial_params = jnp.array([0.0, 0.0, 0.0, jnp.log(1.0), 0.0, jnp.log(1.0)])
        solver = jaxopt.LBFGS(fun=self.negative_log_likelihood, maxiter=100)
        opt_params, res = solver.run(init_params=initial_params, t=self.t, y=self.y)

        self.setup_gp(opt_params, t=self.t)

    def derive_posteriors(self):

        kernel = NUTS(self.numpyro_model)
        mcmc = MCMC(
            kernel,
            num_warmup=1000,
            num_samples=1,
            num_chains=self.num_chains,
            chain_method="parallel",
        )

        rng = jax.random.PRNGKey(0)
        mcmc.run(
            rng,
            t=self.t,
            y=self.y,
        )
        state = mcmc.last_state
        mcmc = MCMC(
            kernel,
            num_warmup=0,
            num_samples=self.converge_step,
            num_chains=self.num_chains,
            chain_method="parallel",
            progress_bar=False,
            jit_model_args=True,
        )
        mcmc.post_warmup_state = state
        print("Finished warm-up")
        for iteration in range(int(self.max_mc_steps / self.converge_step)):
            mcmc.run(mcmc.post_warmup_state.rng_key, t=self.t)
            mcmc.post_warmup_state = mcmc.last_state
            # do some convergence check and bail out of loop if met
        self.posterior_samples = mcmc.get_samples()
        az_summary = az.summary(az.from_numpyro(mcmc))
        print(az_summary)
        print(az_summary["r_hat"])


def example_kernel_fn(prior_sigma=2.0, params=None):
    """
    An example kernel function that builds a sum of two SHO terms.
    Args:
        prior_sigma (float): Standard deviation of the normal priors.
    Returns:
        Kernel: A celerite2.jax kernel object.
    """
    import celerite2.jax.terms as jax_terms

    if params is not None:
        mean, log_sigma1, log_rho1, log_tau, log_sigma2, log_rho2 = params
    else:
        rng_key = random.PRNGKey(34923)
        mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
        log_jitter = numpyro.sample("log_jitter", dist.Normal(0.0, prior_sigma))
        log_sigma1 = numpyro.sample(
            "log_sigma1", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
        log_rho1 = numpyro.sample(
            "log_rho1", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
        log_tau = numpyro.sample(
            "log_tau", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )

        log_sigma2 = numpyro.sample(
            "log_sigma2", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
        log_rho2 = numpyro.sample(
            "log_rho2", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
    term1 = jax_terms.SHOTerm(
        sigma=jnp.exp(log_sigma1), rho=jnp.exp(log_rho1), tau=jnp.exp(log_tau)
    )
    term2 = jax_terms.SHOTerm(sigma=jnp.exp(log_sigma2), rho=jnp.exp(log_rho2), Q=0.25)
    print("in model")
    return term1 + term2


t = np.sort(
    np.append(
        np.random.uniform(0, 3.8, 57),
        np.random.uniform(5.5, 20, 68),
    )
)
y_train = 0.2 * (t - 5) + np.sin(3 * t + 0.1 * (t - 5) ** 2)
y_true = jnp.sin(t) + 0.1 * jax.random.normal(jax.random.PRNGKey(1), shape=t.shape)
t_test = jnp.linspace(0, 20, 500)
yerr = np.random.uniform(0.08, 0.22, len(t))
gp = celerite2_gp(kernel_fn=example_kernel_fn, t=t, yerr=yerr, y=y_train)

gp.minimize()
gp.derive_posteriors()

So if it is recompiling, it will print "in model" multiple times after it enters the loop over iterations.

How about removing those self. attributes and the gp object? Try to make things simple first (e.g. don't use gp.numpyro_dist() but another distribution like Normal). As I mentioned, the test passed, so there might be some issue with your setup.
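
For example, a stripped-down version along these lines (hypothetical names, a plain Normal likelihood instead of gp.numpyro_dist(), no class or self. state) would help isolate it - if "in model" stops printing after the first loop iteration here, the extra recompiles are coming from the celerite2/gp setup:

    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    numpyro.set_host_device_count(2)

    def simple_model(t, y=None):
        print("in model")  # Python-side print: fires whenever the model function is re-traced
        mean = numpyro.sample("mean", dist.Normal(0.0, 2.0))
        log_sigma = numpyro.sample("log_sigma", dist.Normal(0.0, 2.0))
        numpyro.sample("obs", dist.Normal(mean + 0.2 * t, jnp.exp(log_sigma)), obs=y)

    t = jnp.linspace(0.0, 20.0, 125)
    y = 0.2 * (t - 5.0) + jnp.sin(3.0 * t)

    kernel = NUTS(simple_model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1, num_chains=2,
                chain_method="parallel")
    mcmc.run(jax.random.PRNGKey(0), t=t, y=y)

    state = mcmc.last_state
    mcmc = MCMC(kernel, num_warmup=0, num_samples=2500, num_chains=2,
                chain_method="parallel", progress_bar=False)
    mcmc.post_warmup_state = state
    for iteration in range(4):
        mcmc.run(mcmc.post_warmup_state.rng_key, t=t, y=y)
        mcmc.post_warmup_state = mcmc.last_state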