Error when sampling in MCMC

I have refactored some code that previously worked. On the face of it the two approaches should be equivalent, but I am now getting an error.

In the version that works, I have a function that returns a celerite2 Term object (together with the mean value):


@returns_type(Term)
def real_kernel_fn(
    params=None,
    fit=False,
    rng_key=None,
    bounds: dict = None,
    mean_model: MeanFunction = None,
):
    """Create a celerite2 kernel with parameters either fixed (for optimization)
    or sampled (for MCMC).

    Parameters
    ----------
    params : list or array, optional
        Initial parameter values (not in log-space), by default None
    fit : bool, optional
        If True, use the fixed values in params (for optimization); if False,
        sample the parameters with numpyro (for MCMC), by default False
    rng_key : jax.random.PRNGKey, optional
        PRNG key for Numpyro, by default None
    bounds : dict, optional
        (min, max) bounds for each parameter, by default None
    mean_model : MeanFunction, optional
        Mean model passed to _handle_mean, by default None

    Returns
    -------
    tuple
        (jax_terms.RealTerm(a=a, c=c), mean_value)
    """

    mean_value, params = _handle_mean(mean_model, params, fit, rng_key)
    if fit:

        a, c = params  # [mean_model.no_parameters :]

    else:

        log_a = numpyro.sample(
            "log_a",
            dist.Uniform(jnp.log(bounds["a"][0]), jnp.log(bounds["a"][1])),
            rng_key=rng_key,
        )
        log_c = numpyro.sample(
            "log_c",
            dist.Uniform(jnp.log(bounds["c"][0]), jnp.log(bounds["c"][1])),
            rng_key=rng_key,
        )

        a = jnp.exp(log_a)
        c = jnp.exp(log_c)

    return jax_terms.RealTerm(a=a, c=c), mean_value

This function gets called inside gp.compute, which the numpyro model calls:

    def numpyro_model(
        self, t: jax.Array, params: jax.Array = None, fit: bool = False
    ) -> None:

        self.gp.compute(t, params=params, fit=fit)

        log_likelihood = self.gp.log_likelihood(self._lightcurve.y)
        numpyro.deterministic("log_likelihood", log_likelihood)
        numpyro.sample(
            "obs",
            self.gp.numpyro_dist(),
            obs=self._lightcurve.y,
            rng_key=self.rng_key,
        )

The new version abstracts the kernel creation so that it accepts a kernel specification, rather than needing a separate function for every type of kernel. The kernel is created like this:

    def _get_kernel(self, fit=True):

        rng_key = self.rng_key  # jax.random.PRNGKey(self.rng_key)
        terms = []

        for i, term in enumerate(self.kernel_spec.terms):
            # term_cls = term_class
            kwargs = {}

            for name, param_spec in term.parameters.items():
                full_name = f"term{i}_{name}"

                if fit or param_spec.fixed:
                    kwargs[name] = param_spec.value
                else:
                    dist_cls = param_spec.prior
                    val = numpyro.sample(
                        full_name, dist_cls(*param_spec.bounds), rng_key=rng_key
                    )
                    kwargs[name] = val

            terms.append(term.term_class(**kwargs))

        kernel = terms[0]
        for t in terms[1:]:
            kernel += t
        return kernel

In both cases the model is run with mcmc.run:


        kernel = NUTS(
            self.numpyro_model,
            adapt_step_size=True,
            dense_mass=True,
            init_strategy=init_to_value(values=fixed_params),
        )
        mcmc = MCMC(
            kernel,
            num_warmup=num_warmup,
            num_samples=1,
            num_chains=num_chains,
            chain_method="parallel",
            jit_model_args=False,
            progress_bar=progress,
        )

        mcmc.run(
            self.rng_key,
            t=self._lightcurve.times,
        )

However, with the kernel spec version the sampling returns an array of values rather than a single value: val has a length equal to the number of chains, whereas celerite2 expects a single value when creating the kernel Term. With the kernel function, the same sampling returns a single value. I'm unsure why this is happening; as far as I can tell, the two versions should behave the same way, so I am clearly misunderstanding how the sampling is done.

Can anyone explain what may be going on here?

I think you can use the trace handler (see the Effect Handlers page of the NumPyro documentation) to inspect the model. From the snippets alone, it's impossible to tell what's going on in each of your implementations.

Found the problem: it wasn't in the numpyro sampling, but in how I initialise the parameters. 🙃