I want to check the convergence of an MCMC run and stop sampling once it has converged. However, the model is recompiled after every check.
Here is a dummy version of the code to show what I mean:
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import celerite2
import celerite2.jax.terms as jax_terms
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from celerite2.jax.terms import Term
from jax import random
from numpyro.infer import MCMC, NUTS
# Module-level PRNG key (seed 0); shared by the sample sites and the first MCMC run.
rng_key = random.PRNGKey(0)
# Kernel parameter specs
@dataclass
class KernelParameterSpec:
    """Describe a single kernel parameter: its value plus optional metadata.

    The constructor is the one ``@dataclass`` generates from the fields
    below: ``value`` is required, the remaining fields are optional and may
    be passed positionally or by keyword in the declared order.
    """

    # Current value of the parameter.
    value: float
    # Marks the parameter as fixed; defaults to False.
    fixed: bool = False
    # Callable that builds the prior; None when no prior is given.
    prior: Optional[Callable[..., Any]] = None
    # Pair of bounds for the parameter; None when no bounds are given.
    bounds: Optional[Tuple[float, float]] = None
@dataclass
class KernelTermSpec:
    """Pair a kernel term class with the named parameters used to build it."""

    # Class that is instantiated with the parameters as keyword arguments.
    term_class: Type
    # Parameter name -> specification; stored as an OrderedDict copy.
    parameters: Dict[str, KernelParameterSpec]

    def __post_init__(self) -> None:
        """Copy the supplied mapping into an ``OrderedDict``."""
        self.parameters = OrderedDict(self.parameters)
@dataclass
class KernelSpec:
    """A list of kernel terms plus the name of the engine that evaluates them.

    ``use_jax`` is derived in the constructor: it is True when the first
    term's class is a subclass of ``celerite2.jax.terms.Term`` and False
    otherwise (including when ``terms`` is empty). It decides whether the
    array helpers below work with ``jax.numpy`` or ``numpy``.
    """

    # Kernel terms, in the order their parameters are flattened.
    terms: List[KernelTermSpec]
    # Name of the GP engine (e.g. "celerite2"); stored, not interpreted here.
    engine: str

    def __init__(self, terms: List[KernelTermSpec], engine: str):
        self.terms = terms
        self.engine = engine
        # Always assign the flag. Previously it was only set when the check
        # succeeded, so every method below raised AttributeError for
        # non-jax term classes, and an empty ``terms`` raised IndexError.
        self.use_jax = bool(terms) and issubclass(terms[0].term_class, Term)

    def update_params_from_array(self, array: Union[jax.Array, np.ndarray]) -> None:
        """Overwrite every parameter value from a flat array.

        Values are consumed in term order, then parameter order, and fixed
        parameters are included. Each value is stored as a jax array when
        ``use_jax`` is set and as a Python float otherwise.

        Raises:
            IndexError: if ``array`` has fewer entries than there are
                parameters (parameters before the missing entry are
                already updated at that point).
        """
        params = [
            param for term in self.terms for param in term.parameters.values()
        ]
        for i, param in enumerate(params):
            param.value = jnp.array(array[i]) if self.use_jax else float(array[i])

    def get_param_array(self) -> Union[np.ndarray, jax.Array]:
        """Return all parameter values (fixed ones included) as a flat array."""
        values = [
            param.value for term in self.terms for param in term.parameters.values()
        ]
        return jnp.array(values) if self.use_jax else np.array(values, dtype=np.float64)

    def get_bounds_array(self) -> Union[np.ndarray, jax.Array]:
        """Return the bounds of every non-fixed parameter, one pair per row.

        Fixed parameters are skipped, so the result can have fewer rows than
        ``get_param_array`` has entries.

        Raises:
            ValueError: if a non-fixed parameter has no bounds.
        """
        bounds = []
        for term in self.terms:
            for param in term.parameters.values():
                if param.fixed:
                    continue
                if param.bounds is None:
                    raise ValueError("Non-fixed parameter is missing bounds.")
                bounds.append(param.bounds)
        return jnp.array(bounds) if self.use_jax else np.array(bounds, dtype=np.float64)
# Generate the kernel
def get_kernel(kernel_spec):
    """Build the summed kernel described by ``kernel_spec``.

    For each term, every non-fixed parameter becomes a numpyro sample site
    named ``term{i}_{name}`` whose distribution is ``prior(*bounds)``.
    Parameters flagged ``fixed`` are passed to the term class with their
    stored ``value`` instead of being sampled, in line with
    ``KernelSpec.get_bounds_array``, which also skips fixed parameters.

    Args:
        kernel_spec: object exposing ``terms``; each term exposes
            ``term_class`` and a ``parameters`` mapping of name to spec.

    Returns:
        The first term instance with the remaining ones added to it.

    Raises:
        IndexError: if ``kernel_spec.terms`` is empty.
    """
    terms = []
    for i, term in enumerate(kernel_spec.terms):
        kwargs = {}
        for name, param_spec in term.parameters.items():
            if param_spec.fixed:
                # Fixed parameters may have no prior or bounds at all, so
                # they must not go through the sampling path.
                kwargs[name] = param_spec.value
                continue
            # NOTE(review): the same module-level key is handed to every
            # site; presumably a seed handler would otherwise supply a
            # distinct key per site. Confirm this is intended.
            kwargs[name] = numpyro.sample(
                f"term{i}_{name}",
                param_spec.prior(*param_spec.bounds),
                rng_key=rng_key,
            )
        terms.append(term.term_class(**kwargs))

    kernel = terms[0]
    for extra in terms[1:]:
        kernel += extra
    return kernel
# Probabilistic model
def model(t, yerr, y) -> None:
    """Numpyro model: a single celerite2 ``RealTerm`` GP over ``(t, y)``.

    The kernel specification is rebuilt on every call from the module-level
    ``variance_drw`` and ``w_bend`` values, both parameters ("a" and "c")
    get ``dist.Uniform`` priors, and the GP log-likelihood of ``y`` is
    recorded as the deterministic site "log_likelihood". "in model" is
    printed each time the function body runs.

    Args:
        t: input coordinates passed to ``gp.compute``.
        yerr: per-point uncertainties passed to ``gp.compute``.
        y: observed values.
    """
    a_spec = KernelParameterSpec(
        value=variance_drw,
        prior=dist.Uniform,
        bounds=(-10, 50.0),
    )
    c_spec = KernelParameterSpec(
        value=w_bend,
        prior=dist.Uniform,
        bounds=(-10.0, 10.0),
    )
    real_term = KernelTermSpec(
        term_class=jax_terms.RealTerm,
        parameters={"a": a_spec, "c": c_spec},
    )
    spec = KernelSpec(terms=[real_term], engine="celerite2")

    gp = celerite2.jax.GaussianProcess(get_kernel(spec))
    gp.compute(t, yerr=yerr, check_sorted=False)
    gp_log_like = gp.log_likelihood(y)

    print("in model")
    numpyro.deterministic("log_likelihood", gp_log_like)
    # NOTE(review): the observed site scores ``y`` under ``dist.Normal()``
    # with default arguments, and the GP log-likelihood above is only
    # recorded, not used as the likelihood. Confirm this is intended for
    # the demo.
    numpyro.sample("obs", dist.Normal(), obs=y, rng_key=rng_key)
# Run MCMC
def run_mcmc(model, times, fluxes, errors):
    """Warm up NUTS briefly, then sample in repeated batches and join the draws.

    A first ``MCMC`` object runs 10 warm-up steps and 1 sample; its last
    state seeds a second ``MCMC`` object (no warm-up, 100 samples per run)
    that is run ``max_steps // converge_steps`` times, each run continuing
    from the previous run's last state.

    Args:
        model: numpyro model callable accepting ``t``, ``yerr`` and ``y``.
        times: values passed to the model as ``t``.
        fluxes: values passed to the model as ``y``.
        errors: values passed to the model as ``yerr``.

    Returns:
        dict mapping each site name to its draws from all batches. Samples
        are requested with ``group_by_chain=True`` (chain axis first), so
        batches are joined along axis 1, the draw axis.
    """
    kernel = NUTS(model, adapt_step_size=True, dense_mass=True)

    # Short adaptation run whose final state becomes the sampling start point.
    warmup_mcmc = MCMC(
        kernel,
        num_warmup=10,
        num_samples=1,
        num_chains=1,
        jit_model_args=True,
        progress_bar=False,
    )
    # Pass ``t`` by keyword, exactly as the sampling runs below do.
    warmup_mcmc.run(rng_key, t=times, yerr=errors, y=fluxes)

    # NOTE(review): this is a second MCMC object; presumably anything the
    # first object compiled or cached is not shared with it. Confirm against
    # the numpyro MCMC documentation when chasing recompilation.
    mcmc = MCMC(
        kernel,
        num_warmup=0,
        num_samples=100,
        num_chains=1,
        progress_bar=False,
        jit_model_args=True,
    )
    mcmc.post_warmup_state = warmup_mcmc.last_state

    max_steps = 10000
    converge_steps = 1000
    # NOTE(review): each batch draws num_samples=100, not converge_steps
    # samples, so the loop yields 1000 draws in total. Confirm which is meant.
    batches = {}
    for _ in range(max_steps // converge_steps):
        mcmc.run(mcmc.post_warmup_state.rng_key, t=times, yerr=errors, y=fluxes)
        mcmc.post_warmup_state = mcmc.last_state
        for key, draws in mcmc.get_samples(group_by_chain=True).items():
            batches.setdefault(key, []).append(draws)

    # Axis 0 is the chain axis; joining there would present successive
    # batches of one chain as separate chains. One concatenate per site also
    # avoids re-copying the growing array on every iteration.
    return {key: jnp.concatenate(parts, axis=1) for key, parts in batches.items()}
# === Generate synthetic lightcurve using numpy ===
# Seed NumPy's global RNG so the two random draws below are reproducible;
# their order determines the generated data.
np.random.seed(42)
n_points = 500
mean_flux = 100
dt = 1.0
# Evenly spaced sample times: 0, dt, 2*dt, ... (n_points values).
times = np.arange(0, n_points * dt, dt)
# Simulate a smooth trend + noise: a sinusoid of amplitude 5 and period 50
# (in the units of `times`) around mean_flux.
flux_true = mean_flux + 5 * np.sin(2 * np.pi * times / 50)
# Per-point uncertainties drawn from Normal(1.0, 0.2); used both as the noise
# scale below and as `yerr` for the model.
errors = np.random.normal(1.0, 0.2, size=n_points)
noisy_flux = flux_true + np.random.normal(0, errors)
# Global hyperparameters: `model` reads these module-level names when it is
# called, so they must be defined before run_mcmc is invoked.
variance_drw = (mean_flux * 0.1) ** 2 # Variance for DRW: (10% of mean_flux) squared
w_bend = 2 * np.pi / 20 # Angular frequency: 2*pi / 20
# Run MCMC sampling on the synthetic data.
samples = run_mcmc(model, times=times, fluxes=noisy_flux, errors=errors)
Can anyone suggest a way to prevent the recompilation?