Here’s a slightly more elaborate minimal working example of what I would like to do: re-roll a set of draws whenever a condition is not met. The key idea is to use brute-force rejection in only a small part of the model, so that I don't have to solve a difficult and possibly still open mathematical problem (in this case, one involving a bounded polytope) every time I want to try something.
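For reference, the re-roll idea on its own, outside any NumPyro model, can be sketched in plain JAX. This is a minimal sketch assuming four Uniform(0, 0.25) draws and a lower bound of 0.5, as in the example; the function name rejection_sample_ages and its parameters are hypothetical. All randomness comes from jax.random and every value the loop needs is carried explicitly in the lax.while_loop state.

```python
import jax
import jax.numpy as jnp
from jax import random


def rejection_sample_ages(key, lower_bound=0.5, n_ages=4, high=0.25):
    """Hypothetical helper: redraw n_ages Uniform(0, high) values until
    their sum is >= lower_bound. Returns (number of draws, accepted ages)."""

    def cond_fn(val):
        _, _, ages = val
        # Keep looping while the constraint is violated.
        return ages.sum() < lower_bound

    def body_fn(val):
        tries, key, _ = val
        # Split off a fresh subkey for every attempt.
        key, subkey = random.split(key)
        ages = random.uniform(subkey, (n_ages,), minval=0.0, maxval=high)
        return tries + 1, key, ages

    # Start with all-zero ages so the first iteration always runs.
    init_val = (0, key, jnp.zeros(n_ages))
    tries, _, ages = jax.lax.while_loop(cond_fn, body_fn, init_val)
    return tries, ages
```

Because it is pure JAX, it can be wrapped in jax.jit or mapped over a batch of keys with jax.vmap. Note that the loop never terminates if lower_bound exceeds n_ages * high.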
Run once inside a numpyro.handlers.seed context (the commented-out line at the top of the function), it works as intended and returns four draws whose sum meets the criterion (sum >= lower_bound).
Run as a model under NUTS, however, it fails with the JAX error below, probably because I am badly misusing the numpyro.sample machinery. Pyro recently added a Rejector distribution; is there a smarter way to do this?
import jax
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def rejection_sampling_test():
    # with numpyro.handlers.seed(rng_seed=1234):  # running once like this works
    lower_bound = 0.5
    key = numpyro.prng_key()

    def _body_fn(val):
        # Draw a fresh set of four ages, splitting the carried key each time.
        i, key, infected_age1, infected_age2, infected_age3, infected_age4 = val
        key, key_u = random.split(key)
        infected_age1 = numpyro.sample('infected_age1', dist.Uniform(0, 0.25), rng_key=key_u)
        key, key_u = random.split(key)
        infected_age2 = numpyro.sample('infected_age2', dist.Uniform(0, 0.25), rng_key=key_u)
        key, key_u = random.split(key)
        infected_age3 = numpyro.sample('infected_age3', dist.Uniform(0, 0.25), rng_key=key_u)
        key, key_u = random.split(key)
        infected_age4 = numpyro.sample('infected_age4', dist.Uniform(0, 0.25), rng_key=key_u)
        return i + 1, key, infected_age1, infected_age2, infected_age3, infected_age4

    def _cond_fn(val):
        # Keep re-rolling while the sum is below the bound.
        i, key, infected_age1, infected_age2, infected_age3, infected_age4 = val
        return infected_age1 + infected_age2 + infected_age3 + infected_age4 < lower_bound

    # i starts at -1, so how_many_tries ends up counting the rejected draws.
    how_many_tries, _, infected_age1, infected_age2, infected_age3, infected_age4 = jax.lax.while_loop(
        cond_fun=_cond_fn, body_fun=_body_fn, init_val=(-1, key, 0., 0., 0., 0.))
    county_infected = infected_age1 + infected_age2 + infected_age3 + infected_age4
    # Use variables downstream like normal
    # numpyro.sample('infected_age1_observed', dist.BetaProportion(infected_age1, 1000), obs=some_age_specific_input_data)
    # numpyro.sample('county_infected_observed', dist.BetaProportion(county_infected, 1000), obs=some_county_input_data)


nuts_kernel = NUTS(rejection_sampling_test)
mcmc = MCMC(nuts_kernel, num_warmup=30, num_samples=10, num_chains=1)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key)
mcmc.print_summary()
Error returned:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape () and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _body_fn at :5 traced for while_loop.