Rejection sampling in NumPyro

Ooh, ooh, I think I can do this with the ABC kernel I have been working on. It is a bit of a hack: instead of testing whether resampled data is close enough to the real data, I reuse the same code to test whether the sum of the samples exceeds a threshold.

Here is what the code looks like:

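import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC
# ABC is the work-in-progress kernel from the gist mentioned below; it is not part of NumPyro.

def my_model():
    # Four independent Uniform(-10, 10) draws.
    with numpyro.plate('I', 4):
        theta = numpyro.sample('theta', dist.Uniform(-10, 10))

def sum_exceeds_threshold(threshold, proposal):
    # Used as the ABC "distance": 0 when the sum exceeds the bound, inf otherwise.
    return jnp.where(proposal['theta'].sum() > threshold, 0, jnp.inf)

def my_run(model):
    rng_key = random.PRNGKey(12345)
    sum_lower_bound = jnp.array(-1)
    # The lower bound is passed in as the "data"; with threshold=1,
    # only proposals whose distance is 0 should pass.
    kernel = ABC(model,
                 data=sum_lower_bound, threshold=1,
                 summary_statistic=sum_exceeds_threshold,
                 max_attempts_per_sample=1_000)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=100, thinning=1)
    mcmc.run(rng_key)
    posterior_samples = mcmc.get_samples()
    return posterior_samples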

Here is an updated gist with the ABC kernel code this relies on (still a work in progress as I try to understand NumPyro better).
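For anyone who just wants to see the accept/reject logic without the kernel, here is a minimal plain-Python sketch of what I understand the hack to be doing. It is not the ABC kernel from the gist and does not use NumPyro at all; `rejection_sample` and its parameters are names I made up for illustration, with defaults mirroring the values in the code above (four Uniform(-10, 10) draws, a lower bound of -1 on the sum, at most 1,000 attempts per sample).

```python
import random

def rejection_sample(num_samples, sum_lower_bound=-1.0, dim=4,
                     max_attempts_per_sample=1_000, seed=12345):
    """Keep only proposals whose components sum to more than the bound."""
    rng = random.Random(seed)
    accepted = []
    for _ in range(num_samples):
        for _ in range(max_attempts_per_sample):
            # Propose from the prior: dim independent Uniform(-10, 10) values.
            theta = [rng.uniform(-10, 10) for _ in range(dim)]
            if sum(theta) > sum_lower_bound:
                accepted.append(theta)
                break
        else:
            # No proposal passed the test within the attempt budget.
            raise RuntimeError("no proposal accepted within max_attempts_per_sample")
    return accepted

samples = rejection_sample(100)
```

Every entry of `samples` is a list of four values whose sum is above the bound. Since a bit more than half of the prior mass satisfies the constraint here, the attempt budget should almost never be exhausted.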
