Hi @fehiepsi - my bad.
Here is an example that reproduces the problem:
```python
import arviz as az
import celerite2.jax
import jax
import jax.numpy as jnp
import jaxopt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

numpyro.set_host_device_count(10)

class celerite2_gp:
    def __init__(
        self,
        kernel_fn,
        t,
        yerr,
        y,
        max_mc_steps=10000,
        num_chains=10,
        num_warmup=1000,
        converge_step=2500,
    ):
        self.kernel_fn = kernel_fn
        self.t = t
        self.y = y
        self.y_err = yerr
        self.mean = None
        self.max_mc_steps = max_mc_steps
        self.num_chains = num_chains
        self.num_warmup = num_warmup
        self.converge_step = converge_step
    def setup_gp(self, params, t, y=None):
        kernel = self.kernel_fn(params=params)
        self.gp = celerite2.jax.GaussianProcess(kernel)
        self.gp.compute(t, check_sorted=False)

    def negative_log_likelihood(self, params, t, y):
        self.setup_gp(params, t, y)
        return -self.gp.log_likelihood(y)

    def numpyro_model(self, t, y=None, params=None):
        # Under MCMC, params is None, so kernel_fn samples the kernel
        # parameters from their priors.
        self.setup_gp(params, t, y)
        numpyro.sample("obs", self.gp.numpyro_dist(), obs=self.y)
    def minimize(self):
        initial_params = jnp.array([0.0, 0.0, 0.0, jnp.log(1.0), 0.0, jnp.log(1.0)])
        solver = jaxopt.LBFGS(fun=self.negative_log_likelihood, maxiter=100)
        opt_params, res = solver.run(init_params=initial_params, t=self.t, y=self.y)
        self.setup_gp(opt_params, t=self.t)
    def derive_posteriors(self):
        kernel = NUTS(self.numpyro_model)
        # Warm-up run: a single sample, only to obtain a state to restart from.
        mcmc = MCMC(
            kernel,
            num_warmup=self.num_warmup,
            num_samples=1,
            num_chains=self.num_chains,
            chain_method="parallel",
        )
        rng = jax.random.PRNGKey(0)
        mcmc.run(rng, t=self.t, y=self.y)
        state = mcmc.last_state
        # Sampling runs: restart from the warmed-up state and draw blocks of
        # converge_step samples until the convergence check is met.
        mcmc = MCMC(
            kernel,
            num_warmup=0,
            num_samples=self.converge_step,
            num_chains=self.num_chains,
            chain_method="parallel",
            progress_bar=False,
            jit_model_args=True,
        )
        mcmc.post_warmup_state = state
        print("Finished warm-up")
        for iteration in range(int(self.max_mc_steps / self.converge_step)):
            mcmc.run(mcmc.post_warmup_state.rng_key, t=self.t)
            mcmc.post_warmup_state = mcmc.last_state
            # do some convergence check and bail out of the loop if met
        self.posterior_samples = mcmc.get_samples()
        az_summary = az.summary(az.from_numpyro(mcmc))
        print(az_summary)
        print(az_summary["r_hat"])
def example_kernel_fn(prior_sigma=2.0, params=None):
    """
    An example kernel function that builds a sum of two SHO terms.

    Args:
        prior_sigma (float): Standard deviation of the normal priors.
        params: Optional fixed parameter vector; if None, the parameters
            are sampled from their priors.

    Returns:
        Kernel: A celerite2.jax kernel object.
    """
    import celerite2.jax.terms as jax_terms

    if params is not None:
        mean, log_sigma1, log_rho1, log_tau, log_sigma2, log_rho2 = params
    else:
        rng_key = random.PRNGKey(34923)
        mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
        log_jitter = numpyro.sample("log_jitter", dist.Normal(0.0, prior_sigma))
        log_sigma1 = numpyro.sample(
            "log_sigma1", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
        log_rho1 = numpyro.sample(
            "log_rho1", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
        log_tau = numpyro.sample(
            "log_tau", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
        log_sigma2 = numpyro.sample(
            "log_sigma2", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
        log_rho2 = numpyro.sample(
            "log_rho2", dist.Normal(0.0, prior_sigma), rng_key=rng_key
        )
    term1 = jax_terms.SHOTerm(
        sigma=jnp.exp(log_sigma1), rho=jnp.exp(log_rho1), tau=jnp.exp(log_tau)
    )
    term2 = jax_terms.SHOTerm(sigma=jnp.exp(log_sigma2), rho=jnp.exp(log_rho2), Q=0.25)
    # This runs at trace time only, so repeated prints indicate recompilation.
    print("in model")
    return term1 + term2
t = np.sort(
    np.append(
        np.random.uniform(0, 3.8, 57),
        np.random.uniform(5.5, 20, 68),
    )
)
y_train = 0.2 * (t - 5) + np.sin(3 * t + 0.1 * (t - 5) ** 2)
y_true = jnp.sin(t) + 0.1 * jax.random.normal(jax.random.PRNGKey(1), shape=t.shape)
t_test = jnp.linspace(0, 20, 500)
yerr = np.random.uniform(0.08, 0.22, len(t))

gp = celerite2_gp(kernel_fn=example_kernel_fn, t=t, yerr=yerr, y=y_train)
gp.minimize()
gp.derive_posteriors()
```
So if the model is being recompiled, it will print “in model” multiple times after it enters the loop over `iteration`.
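In case it helps to isolate the behavior, here is a minimal sketch of the same probe with the GP machinery stripped out (`toy_model` and the data below are hypothetical stand-ins, not part of the code above). A Python `print` inside the model only fires while JAX traces it, so any prints appearing inside the loop would indicate retracing:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def toy_model(y=None):
    print("in model")  # fires at trace time only, not on every MCMC step
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)

y = jnp.ones(10)
mcmc = MCMC(
    NUTS(toy_model),
    num_warmup=100,
    num_samples=50,
    progress_bar=False,
    jit_model_args=True,
)
mcmc.run(jax.random.PRNGKey(0), y=y)  # tracing/compilation expected here
for _ in range(3):
    mcmc.post_warmup_state = mcmc.last_state
    # With jit_model_args=True and unchanged argument shapes, no further
    # "in model" prints should appear inside this loop.
    mcmc.run(mcmc.post_warmup_state.rng_key, y=y)
```

If that toy loop stays silent while the GP version keeps printing, the retracing is presumably triggered by something specific to the GP model rather than by the restart pattern itself.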