I’m continuing my effort to build my NumPyro skills, and now I’m working on a “vanilla” ABC (approximate Bayesian computation) kernel, as summarized in Algorithm 1 of Clarté et al., which I have included here for convenience.
This `sample` method is the heart of my ABC kernel:
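For readers without the paper at hand, the usual rejection form of vanilla ABC works like this. Draw θ from the prior, simulate a dataset x given θ, and keep θ when the summary statistics of x fall within ε of those of the observed data x⋆. The following is a minimal vectorised JAX sketch of that generic scheme, not necessarily identical in every detail to Algorithm 1 of Clarté et al. The simulator, the summary statistic (sample mean), `N_DATA`, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random

N_DATA = 20  # hypothetical dataset size

def simulate(key):
    """Hypothetical simulator: theta ~ N(0, 5^2), x | theta ~ N(theta, 1), N_DATA draws."""
    k_theta, k_x = random.split(key)
    theta = 5.0 * random.normal(k_theta)
    x = theta + random.normal(k_x, (N_DATA,))
    return theta, x

def distance(x, x_star):
    """Distance between summary statistics; here the summary is the sample mean."""
    return jnp.abs(jnp.mean(x) - jnp.mean(x_star))

def rejection_abc(key, x_star, num_proposals, eps):
    """Vanilla rejection ABC, batched: simulate many (theta, x) pairs at once
    and flag the thetas whose synthetic data lie within eps of x_star."""
    keys = random.split(key, num_proposals)
    thetas, xs = jax.vmap(simulate)(keys)
    d = jax.vmap(distance, in_axes=(0, None))(xs, x_star)
    return thetas, d <= eps
```

The function returns a mask rather than the filtered array because boolean indexing such as `thetas[accepted]` produces a data-dependent shape, which cannot be traced under `jax.jit`. Do the filtering outside any jitted code.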
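```python
import jax
import jax.numpy as jnp
from jax import lax, random

def sample(self, state, model_args, model_kwargs):
    def while_condition_func(val):
        distance, rng_key, proposal, n = val
        # Keep proposing while the simulated data are too far from the
        # observed data and attempts remain.
        return jnp.logical_and(distance > self._threshold,
                               n < self._max_attempts_per_sample)

    def while_body_func(val):
        distance, rng_key, proposal, n = val
        rng_key, sample_key = random.split(rng_key)
        # Draw theta from the prior and a synthetic dataset given theta.
        proposal = self._predictive(sample_key, *model_args, **model_kwargs)
        distance = self._summary_statistic(model_kwargs['obs'], proposal['x_pred'])  # FIXME: how do I get the observed vars?
        # For that matter, how do I resample the values of these observed vars without building it into the model?
        return (distance, rng_key, proposal, n + 1)

    # lax.while_loop requires state.z to have the same pytree structure,
    # shapes and dtypes as the dict returned by self._predictive.
    distance, rng_key, proposal, n = lax.while_loop(
        while_condition_func,
        while_body_func,
        (jnp.inf,        # distance
         state.rng_key,  # rng_key
         state.z,        # proposal
         0))             # iteration
    # If every attempt was rejected, fall back to the previous state for
    # every site (not only 'theta'), so theta and x_pred stay consistent.
    accepted = distance <= self._threshold
    proposal = jax.tree_util.tree_map(
        lambda new, old: jnp.where(accepted, new, old), proposal, state.z)
    return ABCState(proposal, rng_key)
```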
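To check the loop logic outside the kernel class, here is a self-contained sketch of the same step written as a plain function. `simulate` stands in for `self._predictive` and `summary_distance` for `self._summary_statistic`. Both are hypothetical toy choices, and this `ABCState` is a stand-in NamedTuple, not the class from the post. The carry is given explicit dtypes so that `lax.while_loop` sees identical types on every iteration.

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp
from jax import lax, random

class ABCState(NamedTuple):
    z: dict              # current parameters plus their synthetic data
    rng_key: jax.Array

def simulate(key):
    """Hypothetical stand-in for self._predictive: theta from the prior, x_pred given theta."""
    k_theta, k_x = random.split(key)
    theta = 5.0 * random.normal(k_theta)
    return {"theta": theta, "x_pred": theta + random.normal(k_x, (20,))}

def summary_distance(x_obs, x_sim):
    """Hypothetical summary-statistic distance: absolute difference of sample means."""
    return jnp.abs(jnp.mean(x_obs) - jnp.mean(x_sim))

def abc_step(state, x_obs, threshold, max_attempts):
    def cond(val):
        distance, _, _, n = val
        return jnp.logical_and(distance > threshold, n < max_attempts)

    def body(val):
        _, key, _, n = val
        key, sample_key = random.split(key)
        proposal = simulate(sample_key)
        return summary_distance(x_obs, proposal["x_pred"]), key, proposal, n + 1

    init = (jnp.array(jnp.inf, dtype=jnp.float32), state.rng_key, state.z,
            jnp.array(0, dtype=jnp.int32))
    distance, key, proposal, _ = lax.while_loop(cond, body, init)
    accepted = distance <= threshold
    # Keep the previous state, site by site, if no proposal was accepted.
    z = jax.tree_util.tree_map(lambda new, old: jnp.where(accepted, new, old),
                               proposal, state.z)
    return ABCState(z, key)
```

Compared with a batched rejection sampler, the while loop draws one proposal at a time. That matches an MCMC-style kernel interface but serialises the simulations.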
What I would like to know is a better way to (1) get the part of a model that corresponds to the “observed dataset $x^\star$” in the pseudocode, and (2) generate the approximate dataset $x^{(i)}$ for each proposal.
I would also appreciate any feedback on how to make this more idiomatic NumPyro. Thanks for all the work on this package!