Example Kernel: ABC

I’m continuing my effort to build my numpyro skills, and now I’m working on a “vanilla” ABC kernel, as summarized in Algorithm 1 of Clarte et al., which I have included here for convenience.
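For reference, a vanilla ABC sampler in its usual rejection form can be sketched in plain NumPy as below. This assumes Algorithm 1 is the standard rejection scheme: draw θ' from the prior, simulate x from the model given θ', and keep θ' only if the distance between x and the observed x^\star is within a tolerance. The names abc_rejection, sample_prior, simulate, distance and eps are hypothetical, and the toy model (θ ~ Uniform(-10, 10), x ~ Normal(θ, 1)) is an illustrative choice.

```python
import numpy as np

def abc_rejection(x_star, sample_prior, simulate, distance, eps, n_samples, rng):
    """Rejection ABC: keep theta' whenever its simulated data lands within eps of x_star."""
    accepted = []
    while len(accepted) < n_samples:
        theta = sample_prior(rng)           # theta' ~ prior
        x = simulate(theta, rng)            # x ~ model(. | theta')
        if distance(x, x_star) <= eps:      # accept only if close to the observed data
            accepted.append(theta)
    return np.array(accepted)

# Toy problem: theta ~ Uniform(-10, 10), x ~ Normal(theta, 1), 20 data points.
rng = np.random.default_rng(0)
x_star = rng.normal(2.0, 1.0, size=20)      # synthetic "observed" dataset
samples = abc_rejection(
    x_star,
    sample_prior=lambda r: r.uniform(-10, 10),
    simulate=lambda th, r: r.normal(th, 1.0, size=x_star.shape),
    distance=lambda x, y: abs(x.mean() - y.mean()),   # compare sample means
    eps=0.2,
    n_samples=200,
    rng=rng,
)
```

Comparing sample means is only one possible summary statistic; with a small eps the accepted θ values concentrate around the mean of x_star.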

This sample method is the heart of my ABC kernel:

def sample(self, state, model_args, model_kwargs):
    def while_condition_func(val):
        distance, rng_key, proposal, n = val
        return jnp.logical_and(distance > self._threshold,
                               n < self._max_attempts_per_sample)
    def while_body_func(val):
        distance, rng_key, proposal, n = val
        rng_key, sample_key = random.split(rng_key)
        proposal = self._predictive(sample_key, *model_args, **model_kwargs)
        distance = self._summary_statistic(model_kwargs['obs'], proposal['x_pred'])  # FIXME: how do I get the observed vars?
        # for that matter, how do I resample the values of these observed vars without building it into the model?
        return (distance, rng_key, proposal, n+1)
    
    distance, rng_key, proposal, n = \
        lax.while_loop(while_condition_func,
                       while_body_func,
                       (jnp.inf,  # distance
                        state.rng_key,  # rng_key
                        state.z, # proposal
                        0 # iteration
                       ))
                   
    proposal['theta'] = jnp.where(distance <= self._threshold, proposal['theta'], state.z['theta'])
    return ABCState(proposal, rng_key)
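The control flow of the sample method above (keep proposing until the distance drops below the threshold or the attempt budget runs out, then fall back to the previous state) can be sketched in plain Python as follows. The function abc_kernel_step and its arguments are hypothetical; this is a NumPy illustration of the logic, not a NumPyro kernel.

```python
import numpy as np

def abc_kernel_step(theta_prev, x_star, sample_prior, simulate, distance,
                    threshold, max_attempts, rng):
    """One ABC move: draw prior proposals until one is within `threshold`
    of x_star or `max_attempts` draws have been used."""
    dist_val, n, theta = np.inf, 0, theta_prev
    # Same stopping rule as while_condition_func.
    while dist_val > threshold and n < max_attempts:
        theta = sample_prior(rng)                          # fresh proposal
        dist_val = distance(simulate(theta, rng), x_star)  # compare with observed data
        n += 1
    # Same role as the jnp.where: keep the previous state if nothing was accepted.
    return theta if dist_val <= threshold else theta_prev

rng = np.random.default_rng(1)
x_star = np.zeros(5)
prior = lambda r: r.uniform(-10, 10)
sim = lambda th, r: r.normal(th, 1.0, size=5)
gap = lambda x, y: abs(x.mean() - y.mean())

stuck = abc_kernel_step(3.5, x_star, prior, sim, gap, threshold=-1.0, max_attempts=10, rng=rng)  # 3.5
moved = abc_kernel_step(3.5, x_star, prior, sim, gap, threshold=1e6, max_attempts=10, rng=rng)
```

Under jit, Python `while` and `if` cannot branch on traced values, which is why the JAX version needs lax.while_loop for the loop and jnp.where for the final accept/fallback choice.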

What I would like to know is a better way to get the part of a model that corresponds to “observed dataset x^\star” in the pseudocode, and a better way to generate the approximate dataset x^{(i)} in the proposal.

I would appreciate any feedback on how to make this more idiomatic numpyro, as well. Thanks for all the work on this package!

Oops, here is a gist with that code in context, which I meant to include.

Hi @abie, to get both the prior samples and their predictive observations, you can do

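import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.primitives import Messenger

def model(obs):
    theta = numpyro.sample('theta', dist.Uniform(-10, 10))
    x_obs = numpyro.sample('x_obs', dist.Normal(theta, 1), obs=obs)

model_args = (obs,)
model_kwargs = {}
# Here is self._predictive logic (rather than using Predictive).
# seed() does not hand rng keys to observed sites, so uncondition (entered
# last, so it sees each message first) clears the observed value and x_obs is
# drawn from Normal(theta, 1) like a latent site.
class uncondition(Messenger):
    def process_message(self, msg):
        if msg["type"] == "sample" and msg["is_observed"]:
            msg["value"], msg["is_observed"] = None, False

with handlers.trace() as tr, handlers.seed(rng_seed=...), uncondition():
    model(*model_args, **model_kwargs)
proposal_and_obs = {name: site["value"] for name, site in tr.items() if site["type"] == "sample"}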

To get the names of the observed variables, you can do

with handlers.trace() as tr, handlers.seed(rng_seed=0):
    model(*model_args, **model_kwargs)

obs_vars = {name for name, site in tr.items() if site["type"] == "sample" and site["is_observed"]}

Not important: maybe it’s better to have summary_statistic take dictionaries as inputs (mapping observation names to their values, which would also work for models with multiple observations) rather than ndarrays.
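A dictionary-based summary statistic could look like the sketch below. The name dict_distance, the per-site mean as the summary, and the Euclidean combination across sites are illustrative assumptions. In the kernel, obs would hold the observed values for the names in obs_vars, and sim could be proposal_and_obs; only the keys of obs are compared, so latent entries such as theta in sim are ignored.

```python
import numpy as np

def dict_distance(obs, sim, summary=np.mean):
    """Distance between observed and simulated datasets given as {site name: values}.

    Each observed site is reduced to a scalar with `summary`, and the per-site
    gaps are combined into one Euclidean distance, so several observed sites work.
    """
    gaps = [summary(np.asarray(sim[name])) - summary(np.asarray(obs[name])) for name in obs]
    return float(np.sqrt(np.sum(np.square(gaps))))

obs = {"x_obs": np.array([1.0, 2.0, 3.0]), "y_obs": np.array([0.0, 0.0])}
sim = {"x_obs": np.array([2.0, 3.0, 4.0]), "y_obs": np.array([1.0, 1.0]), "theta": 2.0}
d = dict_distance(obs, sim)  # sqrt(1**2 + 1**2) ≈ 1.4142
```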