Ooh, ooh, I think I can do this with the ABC kernel I have been working on. It is a bit of a hack: instead of testing whether resampled data is close enough to the real data, I use the same code to test whether the sum of the samples exceeds a threshold.
Here is what the code looks like:
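import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC
# ABC is the work-in-progress kernel from the gist linked below; it is not part of numpyro.

def my_model():
    with numpyro.plate('I', 4):
        theta = numpyro.sample('theta', dist.Uniform(-10, 10))

# Distance-like statistic: 0 if the proposed thetas sum above `threshold`, inf otherwise.
def sum_exceeds_threshold(threshold, proposal):
    return jnp.where(proposal['theta'].sum() > threshold, 0, jnp.inf)

def my_run(model):
    rng_key = random.PRNGKey(12345)
    sum_lower_bound = jnp.array(-1)
    kernel = ABC(model,
                 data=sum_lower_bound, threshold=1,
                 summary_statistic=sum_exceeds_threshold,
                 max_attempts_per_sample=1_000)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=100, thinning=1)
    mcmc.run(rng_key)
    return mcmc.get_samples()

posterior_samples = my_run(my_model)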
Here is an updated gist with the ABC kernel code this relies on (still a work in progress as I work to understand numpyro better).
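For comparison, the same target (four Uniform(-10, 10) draws conditioned on their sum exceeding -1) can be sampled with plain rejection sampling in JAX, without the ABC kernel or MCMC. This is a minimal sketch, not the gist's implementation: the function name `rejection_sample_sum_constraint` and its `batch_size` parameter are hypothetical, and it runs eagerly (outside `jit`) so that boolean-mask indexing is allowed. It may be handy as a sanity check on what the ABC kernel returns.

```python
import jax
import jax.numpy as jnp


def rejection_sample_sum_constraint(key, num_samples, lower_bound=-1.0,
                                    low=-10.0, high=10.0, dim=4,
                                    batch_size=1_000):
    """Hypothetical helper: draw theta ~ Uniform(low, high)^dim
    conditioned on theta.sum() > lower_bound, by rejection sampling."""
    accepted = []
    n_accepted = 0
    while n_accepted < num_samples:
        key, subkey = jax.random.split(key)
        # Propose a whole batch from the prior at once.
        proposals = jax.random.uniform(subkey, (batch_size, dim),
                                       minval=low, maxval=high)
        # Keep only the rows whose sum satisfies the constraint.
        mask = proposals.sum(axis=1) > lower_bound
        kept = proposals[mask]  # boolean indexing works outside jit
        accepted.append(kept)
        n_accepted += kept.shape[0]
    # Trim to exactly num_samples accepted draws.
    return jnp.concatenate(accepted)[:num_samples]


samples = rejection_sample_sum_constraint(jax.random.PRNGKey(12345), 100)
print(samples.shape)  # (100, 4)
```

Because the sum of four symmetric uniforms is centred on 0, roughly half of the proposals pass the `> -1` check, so a single batch of 1,000 is usually enough for 100 samples.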