I have refactored some code that previously worked. On the face of it the two approaches should be equivalent, but I am now getting an error.
In the version that works I have a function that returns a celerite2 Term object:
@returns_type(Term)
def real_kernel_fn(
    params=None,
    fit=False,
    rng_key=None,
    bounds: dict = None,
    mean_model: MeanFunction = None,
):
    """Build a celerite2 RealTerm kernel together with the model mean value.

    Depending on ``fit``, the amplitude/decay pair is either taken directly
    from ``params`` (optimization path) or drawn from log-uniform priors via
    ``numpyro.sample`` (MCMC path).

    Parameters
    ----------
    params : array-like, optional
        Initial parameter values (linear space, not log-space).
    fit : bool, optional
        If True use the fixed ``params``; if False register numpyro sample
        sites for MCMC. Default False.
    rng_key : optional
        PRNG key forwarded to ``numpyro.sample``; leave None when a seed
        handler (e.g. MCMC) supplies per-site keys.
    bounds : dict, optional
        Maps ``"a"`` and ``"c"`` to their (min, max) prior bounds.
    mean_model : MeanFunction, optional
        Mean function whose parameters are handled by ``_handle_mean``.

    Returns
    -------
    tuple
        ``(jax_terms.RealTerm(a=a, c=c), mean_value)``
    """
    mean_value, params = _handle_mean(mean_model, params, fit, rng_key)
    if fit:
        # Optimization path: caller supplies concrete values.
        a, c = params
    else:
        # MCMC path: sample each parameter in log-space (uniform in the
        # logarithm), then exponentiate back to linear space.
        sampled = {}
        for pname in ("a", "c"):  # order matters: sites trace as log_a, log_c
            lo, hi = bounds[pname]
            log_value = numpyro.sample(
                f"log_{pname}",
                dist.Uniform(jnp.log(lo), jnp.log(hi)),
                rng_key=rng_key,
            )
            sampled[pname] = jnp.exp(log_value)
        a, c = sampled["a"], sampled["c"]
    return jax_terms.RealTerm(a=a, c=c), mean_value
This gets called by the numpyro model in gp.compute:
def numpyro_model(
    self, t: jax.Array, params: jax.Array = None, fit: bool = False
) -> None:
    """Numpyro model: condition the GP on times ``t`` and observe the light curve.

    Statement order is significant: the GP must be computed (kernel built,
    matrix factorized) before its log-likelihood is evaluated, and numpyro
    sites are registered in trace order.

    Parameters
    ----------
    t : jax.Array
        Observation times.
    params : jax.Array, optional
        Fixed kernel parameter values (used on the ``fit=True`` path).
    fit : bool, optional
        If True the kernel uses fixed params; if False sample sites are
        created inside ``gp.compute``, by default False.
    """
    # Build/refresh the kernel and factorize the GP at times t.
    # NOTE(review): no rng_key is forwarded here, so any sample sites created
    # inside gp.compute fall back to the surrounding seed handler — confirm
    # the refactored kernel-spec path matches this behavior.
    self.gp.compute(t, params=params, fit=fit)
    # Expose the GP log-likelihood as a deterministic site for diagnostics.
    log_likelihood = self.gp.log_likelihood(self._lightcurve.y)
    numpyro.deterministic("log_likelihood", log_likelihood)
    # Observed site: conditions the model on the measured fluxes.
    numpyro.sample(
        "obs",
        self.gp.numpyro_dist(),
        obs=self._lightcurve.y,
        rng_key=self.rng_key,
    )
The new version abstracts the kernel creation so that it can accept a kernel specification, rather than having to create a function for every type of kernel. With the kernel created like:
def _get_kernel(self, fit=True):
rng_key = self.rng_key # jax.random.PRNGKey(self.rng_key)
terms = []
for i, term in enumerate(self.kernel_spec.terms):
# term_cls = term_class
kwargs = {}
for name, param_spec in term.parameters.items():
full_name = f"term{i}_{name}"
if fit or param_spec.fixed:
kwargs[name] = param_spec.value
else:
dist_cls = param_spec.prior
val = numpyro.sample(
full_name, dist_cls(*param_spec.bounds), rng_key=rng_key
)
kwargs[name] = val
terms.append(term.term_class(**kwargs))
kernel = terms[0]
for t in terms[1:]:
kernel += t
return kernel
In both cases the model is driven by the same mcmc.run call:
# NUTS sampler over the numpyro model; step size and (dense) mass matrix are
# adapted during warmup, and chains start from the supplied fixed parameters.
kernel = NUTS(
    self.numpyro_model,
    adapt_step_size=True,
    dense_mass=True,
    init_strategy=init_to_value(values=fixed_params),
)
# NOTE(review): num_samples=1 keeps only a single post-warmup draw per chain —
# confirm this is intentional and not a leftover debug setting.
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=1,
    num_chains=num_chains,
    chain_method= "parallel",
    jit_model_args=False,
    progress_bar=progress,
)
# Only the observation times are passed explicitly; everything else is closed
# over on self. MCMC derives per-chain keys from self.rng_key internally.
mcmc.run(
    self.rng_key,
    t=self._lightcurve.times,
)
however, with the kernel-spec version the sampling returns an array of values (val has one entry per chain), whereas celerite2 expects a single scalar when constructing the kernel Term; with the kernel-function version I get a single value as expected. I’m unsure why this is happening — as far as I can tell the two versions should behave the same way, so I am clearly misunderstanding how the sampling is being done.
Can anyone explain what may be going on here?