# Example Kernel: ABC

I’m continuing my effort to build my numpyro skills, and now I’m working on a “vanilla” ABC kernel, as summarized in Algorithm 1 of Clarte et al., which I have included here for convenience.
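For readers without the paper at hand, here is a sketch that assumes Algorithm 1 is the usual rejection form of vanilla ABC. Draw $\theta^{(i)}$ from the prior, simulate a dataset $x^{(i)}$ given $\theta^{(i)}$, and accept $\theta^{(i)}$ when the distance between summaries of $x^{(i)}$ and the observed $x^\star$ is at most a threshold. The batched version below uses plain JAX. The toy prior `Uniform(-10, 10)`, the `Normal(theta, 1)` simulator, and the sample-mean summary are assumptions chosen to match the model used later in this thread, and `simulate`/`rejection_abc` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import random


def simulate(rng_key, theta, n_obs):
    # Toy simulator (assumption): n_obs i.i.d. draws from Normal(theta, 1)
    return theta + random.normal(rng_key, (n_obs,))


def rejection_abc(rng_key, x_star, num_proposals, threshold):
    """Batched rejection ABC: propose from the prior, simulate, keep close draws."""
    prior_key, sim_key = random.split(rng_key)
    # theta^{(i)} ~ Uniform(-10, 10), the prior used in this thread's model
    thetas = random.uniform(prior_key, (num_proposals,), minval=-10.0, maxval=10.0)
    # x^{(i)} ~ f(. | theta^{(i)}): one simulated dataset per proposal
    n_obs = x_star.shape[0]
    sim_keys = random.split(sim_key, num_proposals)
    xs = jax.vmap(lambda k, t: simulate(k, t, n_obs))(sim_keys, thetas)
    # Distance between summaries: absolute difference of the sample means
    distances = jnp.abs(xs.mean(axis=1) - x_star.mean())
    return thetas, distances, distances <= threshold


key_obs, key_abc = random.split(random.PRNGKey(0))
x_star = simulate(key_obs, 2.0, 50)  # stand-in for observed data, true theta = 2
thetas, distances, accepted = rejection_abc(key_abc, x_star, 10_000, 0.1)
abc_draws = thetas[accepted]  # approximate ABC posterior draws of theta
```

Drawing a fixed batch and masking keeps every array shape static, which suits JAX; the `while_loop` in the kernel instead retries one proposal at a time until one is accepted.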

This `sample` method is the heart of my ABC kernel:

```python
def sample(self, state, model_args, model_kwargs):
    def while_condition_func(val):
        distance, rng_key, proposal, n = val
        return jnp.logical_and(distance > self._threshold,
                               n < self._max_attempts_per_sample)

    def while_body_func(val):
        distance, rng_key, proposal, n = val
        rng_key, sample_key = random.split(rng_key)
        proposal = self._predictive(sample_key, *model_args, **model_kwargs)
        distance = self._summary_statistic(model_kwargs['obs'], proposal['x_pred'])  # FIXME: how do I get the observed vars?
        # for that matter, how do I resample the values of these observed vars without building it into the model?
        return (distance, rng_key, proposal, n + 1)

    distance, rng_key, proposal, n = \
        lax.while_loop(while_condition_func,
                       while_body_func,
                       (jnp.inf,  # distance
                        state.rng_key,  # rng_key
                        state.z,  # proposal
                        0  # iteration
                        ))

    proposal['theta'] = jnp.where(distance <= self._threshold, proposal['theta'], state.z['theta'])
    return ABCState(proposal, rng_key)
```
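A self-contained version of the same retry pattern may make one constraint easier to see. Here the model is swapped for a hypothetical toy prior and simulator, and `abc_retry_step` and its arguments are illustrative names only. `lax.while_loop` requires the loop carry to keep the same structure, shapes, and dtypes on every iteration. In the kernel, that means `state.z` would need the same keys and shapes as the dictionary returned by `self._predictive`.

```python
import jax.numpy as jnp
from jax import lax, random


def abc_retry_step(rng_key, theta_current, x_star, threshold, max_attempts):
    """One vanilla-ABC update: propose until accepted or attempts run out."""
    n_obs = x_star.shape[0]

    def propose(key):
        # Toy prior and simulator (assumptions): theta ~ U(-10, 10), x ~ N(theta, 1)
        prior_key, sim_key = random.split(key)
        theta = random.uniform(prior_key, minval=-10.0, maxval=10.0)
        return theta, theta + random.normal(sim_key, (n_obs,))

    def cond_fn(carry):
        distance, _, _, n = carry
        return jnp.logical_and(distance > threshold, n < max_attempts)

    def body_fn(carry):
        _, key, _, n = carry
        key, sample_key = random.split(key)
        theta, x = propose(sample_key)
        distance = jnp.abs(x.mean() - x_star.mean())
        return distance, key, theta, n + 1

    # Explicit dtypes so the initial carry matches what body_fn returns
    init = (jnp.array(jnp.inf, dtype=x_star.dtype), rng_key,
            jnp.asarray(theta_current, dtype=x_star.dtype),
            jnp.array(0, dtype=jnp.int32))
    distance, rng_key, theta, n = lax.while_loop(cond_fn, body_fn, init)
    # Keep the current value if nothing was accepted within max_attempts
    theta = jnp.where(distance <= threshold, theta, theta_current)
    return theta, rng_key, n


x_star = 2.0 + random.normal(random.PRNGKey(1), (50,))
theta, key, n_attempts = abc_retry_step(random.PRNGKey(0), 0.5, x_star, 0.1, 1000)
```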

What I would like to know is a better way to get the part of a model that corresponds to the “observed dataset $x^\star$” in the pseudocode, and a better way to generate the approximate dataset $x^{(i)}$ in the proposal.

I would appreciate any feedback on how to make this more idiomatic numpyro, as well. Thanks for all the work on this package!

Oops, here is a gist with that code in context, which I meant to include.

Hi @abie, to get both the prior samples and their predictive observations, you can do:

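```python
import itertools
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro import handlers

def model(obs):
    theta = numpyro.sample('theta', dist.Uniform(-10, 10))
    x_obs = numpyro.sample('x_obs', dist.Normal(theta, 1), obs=obs)

model_args = (obs,)
model_kwargs = {}
# Here is self._predictive logic (rather than using Predictive): condition_fn
# redraws every sample site, observed ones included. seed gives no key to
# observed sites, so each site gets its own subkey via fold_in.
rng_key = ...  # PRNG key for this attempt
site_index = itertools.count()
resample = lambda msg: msg["fn"](rng_key=random.fold_in(rng_key, next(site_index)),
                                 sample_shape=msg["kwargs"]["sample_shape"])
with handlers.trace() as tr, handlers.seed(rng_seed=rng_key), handlers.condition(condition_fn=resample):
    model(*model_args, **model_kwargs)
proposal_and_obs = {name: site["value"] for name, site in tr.items() if site["type"] == "sample"}
```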

To get the observed variables, you can do:

```python
with handlers.trace() as tr, handlers.seed(rng_seed=0):
    model(*model_args, **model_kwargs)

obs_vars = {name for name, site in tr.items() if site["type"] == "sample" and site["is_observed"]}
```

Not important: maybe it’s better to have `summary_statistic` take its inputs as dictionaries (mapping observation names to their values, which would also work for models with multiple observations) rather than ndarrays.
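Here is a minimal sketch of such a dictionary-based distance. As an illustration only, it assumes each observed site is summarized by its mean and standard deviation, and that the stacked summaries are compared with a Euclidean distance. `summaries` and `dict_distance` are hypothetical names. Only the keys of the observed dictionary are compared, so latent entries such as `theta` in a proposal are ignored.

```python
import jax.numpy as jnp


def summaries(data):
    # One (mean, std) pair per site, in sorted-name order so both sides align
    return jnp.concatenate([jnp.stack([jnp.mean(data[name]), jnp.std(data[name])])
                            for name in sorted(data)])


def dict_distance(observed, simulated):
    # Compare only the sites that are observed; extra simulated sites are dropped
    sim = {name: simulated[name] for name in observed}
    return jnp.linalg.norm(summaries(sim) - summaries(observed))


observed = {'a': jnp.array([1.0, 2.0, 3.0]), 'b': jnp.array([0.0, 0.0])}
shifted = {'a': jnp.array([2.0, 3.0, 4.0]), 'b': jnp.array([0.0, 0.0]), 'theta': jnp.array(5.0)}
d = dict_distance(observed, shifted)
```

With this signature, the FIXME line in the kernel could pass the observed dictionary and the whole proposal, as long as both use the same site names (for example `x_obs` rather than a separate `x_pred`).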