Checking MCMC convergence periodically without recompiling the model

I want to check the convergence of an MCMC run periodically and stop sampling once it has converged. However, the model keeps being recompiled after each check.
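
The print("in model") call inside the model fires once per trace, which is how
I can tell it is being retraced. JAX's compile logging confirms this (a sketch
using the standard jax_log_compiles config flag):

import jax
jax.config.update("jax_log_compiles", True)  # logs every XLA compilation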

Here is a dummy version of the code to show what I mean:

from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import celerite2
import celerite2.jax.terms as jax_terms
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from celerite2.jax.terms import Term
from numpyro.infer import MCMC, NUTS

# Random key
rng_key = jax.random.PRNGKey(0)


# Kernel parameter specs
@dataclass
class KernelParameterSpec:
    value: float
    fixed: bool = False
    prior: Optional[Callable[..., Any]] = None
    bounds: Optional[Tuple[float, float]] = None


@dataclass
class KernelTermSpec:
    term_class: Type
    parameters: Dict[str, KernelParameterSpec]

    def __post_init__(self):
        # Make the parameter ordering explicit
        self.parameters = OrderedDict(self.parameters)


@dataclass
class KernelSpec:
    terms: List[KernelTermSpec]
    engine: str

    def __post_init__(self):
        # use_jax decides between jnp and np arrays below, so it must always
        # be set; derive it from the type of the first term
        self.use_jax = issubclass(self.terms[0].term_class, Term)
    def update_params_from_array(self, array: Union[jax.Array, np.ndarray]) -> None:
        i = 0
        for term in self.terms:
            for name, param in term.parameters.items():
                param.value = jnp.array(array[i]) if self.use_jax else float(array[i])
                i += 1

    def get_param_array(self) -> Union[np.ndarray, jax.Array]:
        values = [
            param.value for term in self.terms for param in term.parameters.values()
        ]
        return jnp.array(values) if self.use_jax else np.array(values, dtype=np.float64)

    def get_bounds_array(self) -> Union[np.ndarray, jax.Array]:
        bounds = []
        for term in self.terms:
            for param in term.parameters.values():
                if not param.fixed:
                    if param.bounds is None:
                        raise ValueError("Non-fixed parameter is missing bounds.")
                    bounds.append(param.bounds)
        return jnp.array(bounds) if self.use_jax else np.array(bounds, dtype=np.float64)


# Generate the kernel
def get_kernel(kernel_spec):
    terms = []
    for i, term in enumerate(kernel_spec.terms):
        kwargs = {}
        for name, param_spec in term.parameters.items():
            full_name = f"term{i}_{name}"
            # The inference backend supplies PRNG keys to sample sites, so no
            # explicit rng_key is passed here
            val = numpyro.sample(full_name, param_spec.prior(*param_spec.bounds))
            kwargs[name] = val
        terms.append(term.term_class(**kwargs))

    kernel = terms[0]
    for t in terms[1:]:
        kernel += t
    return kernel


# Probabilistic model
def model(t, yerr, y) -> None:
    kernel_spec = KernelSpec(
        engine="celerite2",
        terms=[
            KernelTermSpec(
                term_class=jax_terms.RealTerm,
                parameters={
                    "a": KernelParameterSpec(
                        value=variance_drw,
                        prior=dist.Uniform,
                        bounds=(-10, 50.0),
                    ),
                    "c": KernelParameterSpec(
                        value=w_bend,
                        prior=dist.Uniform,
                        bounds=(-10.0, 10.0),
                    ),
                },
            )
        ],
    )

    kernel = get_kernel(kernel_spec)
    gp = celerite2.jax.GaussianProcess(kernel)
    gp.compute(t, yerr=yerr, check_sorted=False)

    # This print fires once per trace, so seeing it on every mcmc.run call
    # shows that the model is being recompiled
    print("in model")
    numpyro.deterministic("log_likelihood", gp.log_likelihood(y))
    # Condition on the data through the GP marginal likelihood
    numpyro.sample("obs", gp.numpyro_dist(), obs=y)


# Run MCMC
def run_mcmc(model, times, fluxes, errors):
    kernel = NUTS(model, adapt_step_size=True, dense_mass=True)
    mcmc = MCMC(
        kernel,
        num_warmup=10,
        num_samples=1,
        num_chains=1,
        # chain_method="parallel",
        jit_model_args=True,
        progress_bar=False,
    )

    mcmc.run(rng_key, t=times, yerr=errors, y=fluxes)
    state = mcmc.last_state

    mcmc = MCMC(
        kernel,
        num_warmup=0,
        num_samples=100,
        num_chains=1,
        # chain_method="parallel",
        progress_bar=False,
        jit_model_args=True,
    )
    mcmc.post_warmup_state = state

    max_steps = 10000
    converge_steps = 1000

    for iteration in range(max_steps // converge_steps):
        mcmc.run(mcmc.post_warmup_state.rng_key, t=times, yerr=errors, y=fluxes)
        mcmc.post_warmup_state = mcmc.last_state

        samples = mcmc.get_samples(group_by_chain=True)
        if iteration == 0:
            all_samples = samples
        else:
            # With group_by_chain=True the arrays have shape (chains, samples, ...),
            # so successive batches are concatenated along the sample axis
            for key in samples:
                all_samples[key] = jnp.concatenate(
                    [all_samples[key], samples[key]], axis=1
                )
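
        # The real code runs the convergence check here on all_samples and
        # breaks out of the loop once it passes (see the sketch below)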

    return all_samples


# === Generate synthetic lightcurve using numpy ===
np.random.seed(42)
n_points = 500
mean_flux = 100
dt = 1.0
times = np.arange(0, n_points * dt, dt)

# Simulate a smooth trend + noise
flux_true = mean_flux + 5 * np.sin(2 * np.pi * times / 50)
errors = np.random.normal(1.0, 0.2, size=n_points)
noisy_flux = flux_true + np.random.normal(0, errors)

# Global hyperparameters
variance_drw = (mean_flux * 0.1) ** 2  # Variance for DRW
w_bend = 2 * np.pi / 20  # Angular frequency

# Run MCMC sampling
samples = run_mcmc(model, times=times, fluxes=noisy_flux, errors=errors)

Can anyone suggest a way to prevent the recompilation?
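
For reference, the convergence check I want to run between batches looks
roughly like this (a sketch: split_gelman_rubin is from numpyro.diagnostics
and expects arrays of shape (chains, samples, ...); the 1.01 threshold is my
own choice):

from numpyro.diagnostics import split_gelman_rubin

def converged(all_samples, threshold=1.01):
    # all_samples holds arrays from get_samples(group_by_chain=True)
    return all(
        jnp.all(split_gelman_rubin(v) < threshold) for v in all_samples.values()
    )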

Did you check your code with a simple model like

def model():
    numpyro.sample("x", dist.Normal(0, 1))

?
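
If the simple model also retraces on every run, the problem is in the MCMC
setup rather than in the celerite2 model. A minimal sketch of such a check
(assuming the same imports as above; the print should appear only once if
nothing is retraced):

def simple_model():
    print("tracing simple_model")  # fires once per trace
    numpyro.sample("x", dist.Normal(0, 1))

mcmc = MCMC(NUTS(simple_model), num_warmup=10, num_samples=10, progress_bar=False)
mcmc.run(rng_key)
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key)  # the second run should reuse the compiled step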