Convergence check of MCMC without recompiling

I want to check the convergence of an MCMC run periodically and exit once it has converged. However, the model seems to be recompiled after every check.

Here is a dummy version of the code to show what I mean:

from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import celerite2
import celerite2.jax.terms as jax_terms
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from celerite2.jax.terms import Term
from jax import random
from numpyro.infer import MCMC, NUTS

# Module-level PRNG key (seed 0), shared by the model and the MCMC driver below.
rng_key = random.PRNGKey(0)


# Kernel parameter specs
@dataclass
class KernelParameterSpec:
    """Specification of a single kernel parameter.

    The constructor is the one generated by ``@dataclass``; it accepts the
    fields below in declaration order, positionally or by keyword.
    """

    # Current (or initial) value of the parameter.
    value: float
    # Whether the parameter is held fixed rather than treated as free.
    fixed: bool = False
    # Callable that builds the prior; elsewhere in this file it is called as
    # ``prior(*bounds)``.
    prior: Optional[Callable[..., Any]] = None
    # ``(low, high)`` pair; required for non-fixed parameters by
    # ``KernelSpec.get_bounds_array``.
    bounds: Optional[Tuple[float, float]] = None


@dataclass
class KernelTermSpec:
    """One kernel term: the class to instantiate and its named parameter specs."""

    # Class that is called with the parameter values as keyword arguments.
    term_class: Type
    # Parameter name -> specification; kept as an ``OrderedDict`` after init.
    parameters: Dict[str, KernelParameterSpec]

    def __post_init__(self) -> None:
        # Keep an ordered copy of whatever mapping the caller handed in.
        self.parameters = OrderedDict(self.parameters)


@dataclass
class KernelSpec:
    """A kernel described as a list of term specifications.

    ``use_jax`` is derived from the first term: it is ``True`` when that
    term's class is a subclass of the celerite2 JAX ``Term`` and ``False``
    otherwise. ``terms`` must therefore contain at least one entry.
    """

    terms: List[KernelTermSpec]
    engine: str

    def __init__(self, terms: List[KernelTermSpec], engine: str):
        self.terms = terms
        self.engine = engine
        # Always define the flag. It used to be set only on the JAX branch, so
        # specs built from non-JAX terms raised AttributeError on first use.
        self.use_jax = issubclass(self.terms[0].term_class, Term)

    def _iter_params(self):
        """Yield every parameter spec, term by term, in declaration order."""
        for term in self.terms:
            yield from term.parameters.values()

    def update_params_from_array(self, array: Union[jax.Array, np.ndarray]) -> None:
        """Overwrite each parameter's ``value`` with the matching entry of ``array``.

        Entries are consumed in the same order as ``get_param_array`` emits
        them (all parameters, fixed ones included).
        """
        for i, param in enumerate(self._iter_params()):
            param.value = jnp.array(array[i]) if self.use_jax else float(array[i])

    def get_param_array(self) -> Union[np.ndarray, jax.Array]:
        """Return the current values of all parameters as a flat array."""
        values = [param.value for param in self._iter_params()]
        return jnp.array(values) if self.use_jax else np.array(values, dtype=np.float64)

    def get_bounds_array(self) -> Union[np.ndarray, jax.Array]:
        """Return the bounds of every non-fixed parameter, one row per parameter.

        Raises:
            ValueError: if a non-fixed parameter has no bounds.
        """
        bounds = []
        for param in self._iter_params():
            if param.fixed:
                continue
            if param.bounds is None:
                raise ValueError("Non-fixed parameter is missing bounds.")
            bounds.append(param.bounds)
        return jnp.array(bounds) if self.use_jax else np.array(bounds, dtype=np.float64)


# Generate the kernel
def get_kernel(kernel_spec):
    """Build the kernel described by ``kernel_spec`` inside a NumPyro model.

    Each free parameter becomes a sample site named ``term{i}_{name}`` whose
    distribution is ``prior(*bounds)``. Parameters marked ``fixed`` are passed
    to the term constructor with their stored ``value`` and create no sample
    site. The terms are then summed with ``+`` in declaration order.

    Args:
        kernel_spec: object exposing ``terms``; each term exposes
            ``term_class`` and a ``parameters`` mapping of name -> spec.

    Returns:
        The first term, or the sum of all terms when there are several.

    Raises:
        ValueError: if a free parameter has no prior or no bounds.
    """
    terms = []
    for i, term in enumerate(kernel_spec.terms):
        kwargs = {}
        for name, param_spec in term.parameters.items():
            if param_spec.fixed:
                # Fixed parameters are constants, not latent variables.
                kwargs[name] = param_spec.value
                continue

            full_name = f"term{i}_{name}"
            if param_spec.prior is None or param_spec.bounds is None:
                raise ValueError(
                    f"Free parameter '{full_name}' needs both a prior and bounds."
                )
            # No explicit rng_key: the inference machinery supplies the
            # randomness. Pinning every site to one constant module-level key
            # would make all prior draws identical whenever the site is
            # actually sampled.
            kwargs[name] = numpyro.sample(
                full_name, param_spec.prior(*param_spec.bounds)
            )
        terms.append(term.term_class(**kwargs))

    kernel = terms[0]
    for t in terms[1:]:
        kernel += t
    return kernel


# Probabilistic model
def model(t, yerr, y) -> None:
    """NumPyro model: a single celerite2 ``RealTerm`` GP evaluated on ``(t, y)``.

    The kernel parameters ``a`` and ``c`` get uniform priors; their initial
    values come from the module-level ``variance_drw`` and ``w_bend``.

    NOTE(review): the GP log-likelihood is only stored as a deterministic
    site. The observed site is a standard ``Normal`` that does not depend on
    the kernel parameters -- confirm that this is intended for the demo.
    """
    a_spec = KernelParameterSpec(
        value=variance_drw,
        prior=dist.Uniform,
        bounds=(-10, 50.0),
    )
    c_spec = KernelParameterSpec(
        value=w_bend,
        prior=dist.Uniform,
        bounds=(-10.0, 10.0),
    )
    real_term = KernelTermSpec(
        term_class=jax_terms.RealTerm,
        parameters={"a": a_spec, "c": c_spec},
    )
    kernel_spec = KernelSpec(terms=[real_term], engine="celerite2")

    gp = celerite2.jax.GaussianProcess(get_kernel(kernel_spec))
    gp.compute(t, yerr=yerr, check_sorted=False)
    log_likelihood = gp.log_likelihood(y)

    # Runs every time this Python body is executed, which makes each call of
    # the model function visible on stdout.
    print("in model")

    numpyro.deterministic("log_likelihood", log_likelihood)
    numpyro.sample("obs", dist.Normal(), obs=y, rng_key=rng_key)


# Run MCMC
def run_mcmc(model, times, fluxes, errors, *, max_steps=10000, converge_steps=1000):
    """Warm up NUTS once, then draw samples in chunks of ``converge_steps``.

    A short first run (10 warm-up steps, 1 draw) adapts the sampler. Its final
    state seeds a second ``MCMC`` object that is run repeatedly without
    warm-up, ``max_steps // converge_steps`` times. No convergence criterion is
    evaluated yet, so the loop always runs to completion.

    Args:
        model: NumPyro model called as ``model(times, yerr=errors, y=fluxes)``.
        times, fluxes, errors: data forwarded to the model.
        max_steps: total number of post-warm-up draws to collect.
        converge_steps: draws per chunk (the interval between checks).

    Returns:
        dict mapping site name to an array of draws grouped by chain, with all
        chunks joined along the draw axis.

    Raises:
        ValueError: if the step counts would produce no chunk at all.
    """
    if converge_steps <= 0 or max_steps < converge_steps:
        raise ValueError("Need 0 < converge_steps <= max_steps.")

    kernel = NUTS(model, adapt_step_size=True, dense_mass=True)

    warmup = MCMC(
        kernel,
        num_warmup=10,
        num_samples=1,
        num_chains=1,
        jit_model_args=True,
        progress_bar=False,
    )
    warmup.run(rng_key, times, yerr=errors, y=fluxes)

    mcmc = MCMC(
        kernel,
        num_warmup=0,
        # One run == one chunk, so the chunk size and the loop count agree.
        num_samples=converge_steps,
        num_chains=1,
        progress_bar=False,
        jit_model_args=True,
    )
    mcmc.post_warmup_state = warmup.last_state

    chunks = []
    for _ in range(max_steps // converge_steps):
        mcmc.run(mcmc.post_warmup_state.rng_key, times, yerr=errors, y=fluxes)
        # Continue the next chunk from where this one stopped.
        mcmc.post_warmup_state = mcmc.last_state
        chunks.append(mcmc.get_samples(group_by_chain=True))

    # With group_by_chain=True the leading axis is the chain and axis 1 is the
    # draw, so new chunks must extend axis 1 (axis 0 would fake extra chains).
    return {
        site: jnp.concatenate([chunk[site] for chunk in chunks], axis=1)
        for site in chunks[0]
    }


# === Synthetic light curve (NumPy) and global hyperparameters ===
n_points = 500
mean_flux = 100
dt = 1.0
times = np.arange(0, n_points * dt, dt)

# Global hyperparameters; `model` reads these at call time.
variance_drw = (0.1 * mean_flux) ** 2  # Variance for DRW
w_bend = 2 * np.pi / 20  # Angular frequency

# Noise-free signal: a sinusoid around the mean flux.
flux_true = 5 * np.sin(2 * np.pi * times / 50) + mean_flux

# Seeded draws: per-point error bars first, then noise scaled by them.
np.random.seed(42)
errors = np.random.normal(loc=1.0, scale=0.2, size=n_points)
noisy_flux = np.random.normal(loc=0, scale=errors) + flux_true

# Run MCMC sampling on the noisy light curve.
samples = run_mcmc(model, times=times, fluxes=noisy_flux, errors=errors)

Can anyone suggest a way to prevent the recompilation?

did you check your code with a simple model like

def model():
    """Minimal model: one latent site ``x`` with a Normal(0, 1) prior."""
    numpyro.sample("x", dist.Normal(0, 1))

?

Hi @fehiepsi - it still does the recompilation even with the simple model:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Module-level PRNG key (seed 0) used for the warm-up run.
rng_key = jax.random.PRNGKey(seed=0)


def model() -> None:
    """Minimal model: one Normal(0, 1) site ``x``, plus a print on every call."""
    standard_normal = dist.Normal(0, 1)
    numpyro.sample("x", standard_normal)
    # Emitted each time the Python body of the model is executed.
    print("in model")


# Run MCMC
def run_mcmc(model, *, max_steps=10000, converge_steps=1000):
    """Warm up NUTS once, then draw samples in chunks of ``converge_steps``.

    A short first run (10 warm-up steps, 1 draw) adapts the sampler. Its final
    state seeds a second ``MCMC`` object that is run repeatedly without
    warm-up, ``max_steps // converge_steps`` times. No convergence criterion is
    evaluated yet, so the loop always runs to completion.

    Args:
        model: NumPyro model taking no arguments.
        max_steps: total number of post-warm-up draws to collect.
        converge_steps: draws per chunk (the interval between checks).

    Returns:
        dict mapping site name to an array of draws grouped by chain, with all
        chunks joined along the draw axis.

    Raises:
        ValueError: if the step counts would produce no chunk at all.
    """
    if converge_steps <= 0 or max_steps < converge_steps:
        raise ValueError("Need 0 < converge_steps <= max_steps.")

    kernel = NUTS(model, adapt_step_size=True, dense_mass=True)

    warmup = MCMC(
        kernel,
        num_warmup=10,
        num_samples=1,
        num_chains=1,
        jit_model_args=True,
        progress_bar=False,
    )
    warmup.run(rng_key)

    mcmc = MCMC(
        kernel,
        num_warmup=0,
        # One run == one chunk, so the chunk size and the loop count agree.
        num_samples=converge_steps,
        num_chains=1,
        progress_bar=False,
        jit_model_args=True,
    )
    mcmc.post_warmup_state = warmup.last_state

    chunks = []
    for _ in range(max_steps // converge_steps):
        mcmc.run(mcmc.post_warmup_state.rng_key)
        # Continue the next chunk from where this one stopped.
        mcmc.post_warmup_state = mcmc.last_state
        chunks.append(mcmc.get_samples(group_by_chain=True))

    # With group_by_chain=True the leading axis is the chain and axis 1 is the
    # draw, so new chunks must extend axis 1 (axis 0 would fake extra chains).
    return {
        site: jnp.concatenate([chunk[site] for chunk in chunks], axis=1)
        for site in chunks[0]
    }


# Run MCMC sampling on the argument-free model and keep the collected draws
# (a dict of arrays keyed by site name) at module level.
samples = run_mcmc(model)

Could you double-check? I tried the code on Colab but didn't see the issue.

Interesting. I’ve just checked and the “in model” gets printed multiple times, so I assume it is recompiling each time.

I thought it might be because I was on a Mac M2, but I just checked on an AMD chip and the same thing happens:

python kernel_spec2.py 
in model
in model
in model
in model

Here are the relevant versions from the env:

Package                   Version
------------------------- --------------
numpy                     2.1.2
numpyro                   0.18.0
jax                       0.6.0
jaxlib                    0.6.0

I see 4 print statements so it’s unlikely that they come from the iterations. To initialize the MCMC chain, we need to inspect the model several times (checking for transforms, checking for grad, etc.)

1 Like

Hi @fehiepsi

It looks like having the numpyro.deterministic() line in the model is what causes the recompilation. If it is uncommented in the following code, then “in model” is printed multiple times.

from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import celerite2
import celerite2.jax.terms as jax_terms
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from celerite2.jax.terms import Term
from jax import random
from numpyro.infer import MCMC, NUTS

# Sampling grid and mean level of the synthetic light curve.
n_points = 500
mean_flux = 100
dt = 1.0
times = np.arange(0, n_points * dt, dt)

# Kernel hyperparameters read by the model.
variance_drw = (0.1 * mean_flux) ** 2  # Variance for DRW
w_bend = 2 * np.pi / 20  # Angular frequency

# Seeds: one JAX key for the sampler, one NumPy seed for the fake data.
rng_key = random.PRNGKey(0)
np.random.seed(42)

# Smooth sinusoidal trend, per-point error bars, then noise scaled by them.
flux_true = 5 * np.sin(2 * np.pi * times / 50) + mean_flux
errors = np.random.normal(loc=1.0, scale=0.2, size=n_points)
noisy_flux = np.random.normal(loc=0, scale=errors) + flux_true


class TestRecomp:
    """Reproduction harness: a GP model method plus a chunked NUTS driver."""

    def class_model(self, t, yerr, y) -> None:
        """NumPyro model with a fixed ``RealTerm`` kernel and one latent site ``x``.

        The kernel parameters come from the module-level ``variance_drw`` and
        ``w_bend``. The GP log-likelihood is computed but only recorded when
        the ``numpyro.deterministic`` line below is uncommented.
        """

        kernel = jax_terms.RealTerm(a=variance_drw, c=w_bend)

        gp = celerite2.jax.GaussianProcess(kernel)
        gp.compute(t, yerr=yerr, check_sorted=False)

        log_likelihood = gp.log_likelihood(y)
        # Emitted each time the Python body of the model is executed.
        print("in model")
        # Toggle for the experiment: uncomment to add the deterministic site.
        # numpyro.deterministic("log_likelihood", log_likelihood)
        numpyro.sample("obs", dist.Normal(), obs=y, rng_key=rng_key)
        numpyro.sample("x", dist.Normal(0, 1))

    # Run MCMC
    def run_mcmc(self, times, fluxes, errors, *, max_steps=10000, converge_steps=100):
        """Warm up NUTS once, then draw samples in chunks of ``converge_steps``.

        A short first run (10 warm-up steps, 1 draw) adapts the sampler. Its
        final state seeds a second ``MCMC`` object that is run repeatedly
        without warm-up, ``max_steps // converge_steps`` times. No convergence
        criterion is evaluated yet, so the loop always runs to completion.

        Args:
            times, fluxes, errors: data forwarded to ``class_model``.
            max_steps: total number of post-warm-up draws to collect.
            converge_steps: draws per chunk (the interval between checks).

        Returns:
            dict mapping site name to an array of draws grouped by chain, with
            all chunks joined along the draw axis.

        Raises:
            ValueError: if the step counts would produce no chunk at all.
        """
        if converge_steps <= 0 or max_steps < converge_steps:
            raise ValueError("Need 0 < converge_steps <= max_steps.")

        kernel = NUTS(self.class_model, adapt_step_size=True, dense_mass=True)

        warmup = MCMC(
            kernel,
            num_warmup=10,
            num_samples=1,
            num_chains=1,
            jit_model_args=True,
            progress_bar=False,
        )
        warmup.run(rng_key, times, yerr=errors, y=fluxes)

        mcmc = MCMC(
            kernel,
            num_warmup=0,
            # One run == one chunk, so chunk size and loop count agree.
            num_samples=converge_steps,
            num_chains=1,
            progress_bar=False,
            jit_model_args=True,
        )
        mcmc.post_warmup_state = warmup.last_state

        chunks = []
        for _ in range(max_steps // converge_steps):
            mcmc.run(mcmc.post_warmup_state.rng_key, times, yerr=errors, y=fluxes)
            # Continue the next chunk from where this one stopped.
            mcmc.post_warmup_state = mcmc.last_state
            chunks.append(mcmc.get_samples(group_by_chain=True))

        # With group_by_chain=True the leading axis is the chain and axis 1 is
        # the draw, so new chunks must extend axis 1 (axis 0 would fake extra
        # chains).
        return {
            site: jnp.concatenate([chunk[site] for chunk in chunks], axis=1)
            for site in chunks[0]
        }


# Build the reproduction harness (it holds no state of its own).
mcmc_runner = TestRecomp()


# Run the chunked sampler on the synthetic light curve; the return value is
# discarded because only the "in model" prints matter for this experiment.
mcmc_runner.run_mcmc(times=times, fluxes=noisy_flux, errors=errors)

Is there a way of avoiding this?

Are you asking how to reduce the number of times we trace the model? As mentioned before, we need to traverse the model several times to inspect it. If there are deterministic sites, we need extra tracing to collect them. If you don't want to collect deterministic sites during sampling, you can disable them through a model argument like:

if add_deterministic:
    numpyro.deterministic(...)

then do mcmc as normal. Finally, use Predictive to collect deterministic sites.

Or are you seeing adding deterministic sites trigger recompiling in “every” iteration in your code?

Thanks @fehiepsi,
However, it looks like adding deterministic sites triggers recompilation in every iteration of the code.

If I print a message on each iteration, "in model" gets printed every time:

in model

Iteration 1/100
in model
in model
in model
Iteration 2/100
in model
Iteration 3/100
in model
Iteration 4/100
in model
Iteration 5/100
in model
Iteration 6/100
in model
Iteration 7/100
in model
Iteration 8/100
in model
Iteration 9/100

If I get rid of the deterministic site, it doesn't:

in model
Iteration 1/100
in model
Iteration 2/100
Iteration 3/100
Iteration 4/100
Iteration 5/100
Iteration 6/100
Iteration 7/100
Iteration 8/100
Iteration 9/100
Iteration 10/100
Iteration 11/100

Is this not what is expected?

yeah, it’s not expected. Does it also happen with the simple model above (with deterministic site added)?

As mentioned above, you can add a flag to the model to defer the computation of deterministic sites to later.

Hi @fehiepsi

Yeah - it happens with the simple model too. How would I use the flag in practice to defer the computation? It isn't clear to me!

Thanks!

you can use

def model(add_deterministic=False):
    if ...

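and proceed as in my comment above: run MCMC with the default add_deterministic=False so the deterministic site is skipped during sampling, then call Predictive(model, posterior_samples=mcmc.get_samples())(rng_key, add_deterministic=True) to compute the deterministic sites from the posterior samples afterwards. A model is just a function; you can add logic to disable/enable the behaviors that you want.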