Rejection sampling in NumPyro

How can I correctly implement rejection sampling in NumPyro?

In this application, I need the sum of four random variables to be above some lower bound.

I am not interested in a clever reparameterization: the real application is much more complex, I have explored/am exploring alternatives, and I just need an up/down, yes/no answer on whether NumPyro is capable of rejection sampling as a strategy.

I have a nearly working example of a single-sample rejection using just JAX and a while_loop. Only nearly, because I do not understand how to access the random key provided by the handlers.seed context.

import jax
from jax import random
import numpyro

with numpyro.handlers.seed(rng_seed=10):  # seed 5 gets us negatives
    upper_bound = 2
    key = jax.random.PRNGKey(seed=0)  # how do I get this from the seed handler instead?

    def _body_fn(val):
        i, key, draw = val
        key, key_u = random.split(key)
        draw = random.normal(key_u)
        return i + 1, key, draw

    def _cond_fn(val):
        i, _, draw = val
        # keep re-drawing until the sample exceeds the bound
        return draw <= upper_bound

    ret = jax.lax.while_loop(cond_fun=_cond_fn,
                             body_fun=_body_fn,
                             init_val=(-1, key, 0.))
    ret

Dropping in a dist.Normal sample statement currently results in an infinite loop, and previously threw errors about an uncaptured effect. In any case, I doubt this is the correct route and need someone to explain the right way to frame this.

If you just need an up/down yes/no, then the answer is yes.

This is not a NumPyro-specific question; you can find similar patterns in the jax.random samplers. Regarding getting random keys under the seed handler, you can use key = numpyro.prng_key().
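For example, a minimal sketch of getting a key under the seed handler (so no manual PRNGKey is needed):

import jax
import numpyro

with numpyro.handlers.seed(rng_seed=10):
    # numpyro.prng_key() returns a fresh key managed by the seed handler
    key = numpyro.prng_key()
    u = jax.random.normal(key)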

I guess this doesn't apply to you (it looks like you are not using the sampler to build a NumPyro model, which would not require the seed handler), but I'll post it here for other readers of this thread: after you have a sampler (e.g. a rejection sampler) for a variable, you can create a NumPyro distribution for it, as in the truncated distribution tutorial.


Ooh, ooh, I think I can do this with the ABC kernel I have been working on. It is a bit of a hack, but instead of testing whether resampled data is close enough to the real data, I can use the same code to test whether the sum of the samples exceeds a threshold.

Here is what the code looks like:

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC
# ABC is the work-in-progress kernel from the gist linked below

def my_model():
    with numpyro.plate('I', 4):
        theta = numpyro.sample('theta', dist.Uniform(-10, 10))

def sum_exceeds_threshold(threshold, proposal):
    # distance 0 (accept) if the sum clears the threshold, inf (reject) otherwise
    return jnp.where(proposal['theta'].sum() > threshold, 0, jnp.inf)

def my_run(model):
    rng_key = random.PRNGKey(12345)
    sum_lower_bound = jnp.array(-1)
    kernel = ABC(model,
                 data=sum_lower_bound, threshold=1,
                 summary_statistic=sum_exceeds_threshold,
                 max_attempts_per_sample=1_000)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=100, thinning=1)
    mcmc.run(rng_key)
    posterior_samples = mcmc.get_samples()
    return posterior_samples

Here is an updated gist with the ABC kernel code this relies on (still a work in progress as I work to understand NumPyro better).


Thank you for the numpyro.prng_key() hint; I spent ages looking for that.

I’m sorry, I should have been clearer: I am writing a NumPyro model, I am familiar with the truncated distribution tutorials, and I am specifically asking about the NumPyro-like way to draw a complicated sample and then re-roll if it does not meet a potentially very complicated criterion. I ask because rejection sampling is usually frowned upon: whenever it comes up, someone tends to respond by pointing to an existing or theoretically possible reparameterization of a simple distribution, and I could find no NumPyro examples.

So in my case I have a COVID-19 model with multiple pieces of information at multiple scales of aggregation. Case counts by county-age-day from one database serve as a lower bound on possible real infections. I also have case counts by county-day from a different, better database that serve as a lower bound on the sum of those age groups. So I can either sample a truncated distribution at the county level and then try to split it across ages with some other truncated stick-breaking distribution, or I can sample truncated distributions at the age level, sum them, and reject if the sum is below the known county total.

This is just one complication; there are others, hence the general interest in re-rolling when a more complicated condition isn’t met.

My JAX example will produce a single truncated normal, and with a trivial rewrite it’ll produce samples whose sum is truncated as I describe. I imagine the NumPyro-like way to do it is in Abie’s direction, but he rejects at the very end, and I’m hoping to reject at the sample stage so the rest of the model can carry on (rejection of a single draw is low-dimensional and trivial, but rejection of the full model is very high-dimensional and likely to almost never pass).

(Congrats on the 0.9.0 release today BTW just saw that and looks great)

Here’s a slightly more elaborate minimal working example to illustrate my dream of being able to re-roll if a condition is not met. The key idea is using brute-force rejection in only a small part of the model, to avoid solving a difficult and possibly still-open mathematical problem (in this case, sampling from a bounded polytope) every time I want to try something.

Run a single time in the handlers.seed context, it works as intended, returning 4 draws whose sum meets the criterion.

Run as a model, it returns the following JAX error, probably because I am badly misusing the numpyro.sample machinery. Pyro recently added a Rejector distribution. Is there a smarter way to do this?

import jax
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def rejection_sampling_test():
    # with numpyro.handlers.seed(rng_seed=1234):
    lower_bound = 0.5
    key = numpyro.prng_key()

    def _body_fn(val):
        i, key, infected_age1, infected_age2, infected_age3, infected_age4 = val
        key, key_u = random.split(key)
        infected_age1 = numpyro.sample('infected_age1', dist.Uniform(0, 0.25), rng_key=key_u)
        key, key_u = random.split(key_u)
        infected_age2 = numpyro.sample('infected_age2', dist.Uniform(0, 0.25), rng_key=key_u)
        key, key_u = random.split(key_u)
        infected_age3 = numpyro.sample('infected_age3', dist.Uniform(0, 0.25), rng_key=key_u)
        key, key_u = random.split(key_u)
        infected_age4 = numpyro.sample('infected_age4', dist.Uniform(0, 0.25), rng_key=key_u)
        return i + 1, key, infected_age1, infected_age2, infected_age3, infected_age4

    def _cond_fn(val):
        i, key, infected_age1, infected_age2, infected_age3, infected_age4 = val
        return infected_age1 + infected_age2 + infected_age3 + infected_age4 < lower_bound

    how_many_tries, _, infected_age1, infected_age2, infected_age3, infected_age4 = jax.lax.while_loop(
        cond_fun=_cond_fn, body_fun=_body_fn, init_val=(-1, key, 0., 0., 0., 0.))
    county_infected = infected_age1 + infected_age2 + infected_age3 + infected_age4
    # Use variables downstream like normal, e.g.
    # numpyro.sample('infected_age1_observed', dist.BetaProportion(infected_age1, 1000), obs=some_age_specific_input_data)
    # numpyro.sample('county_infected_observed', dist.BetaProportion(county_infected, 1000), obs=some_county_input_data)

nuts_kernel = NUTS(rejection_sampling_test)
mcmc = MCMC(nuts_kernel, num_warmup=30, num_samples=10, num_chains=1)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key)
mcmc.print_summary()

Error returned:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape () and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _body_fn at :5 traced for while_loop.

Hi @rexdouglass, I don’t think you can use rejection sampling inside NUTS. NUTS draws samples based on the variables’ densities (and for constrained supports, we need to define corresponding transforms). If the conditions are simple, you can define the support and the corresponding transform for it, e.g.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints, transforms

class AcceptSupport(constraints.Constraint):
    event_dim = 1

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def __call__(self, x):
        return (x.sum(-1) > self.lower_bound) & constraints.interval(0, 0.25)(x).all(-1)

    def feasible_like(self, prototype):
        return jnp.ones_like(prototype) * self.lower_bound

accept_support = AcceptSupport(lower_bound)

class AcceptTransform(transforms.Transform):
    domain = constraints.real_vector
    codomain = accept_support
    # then define the transform, its inverse, and its jacobian
    # (I'm not sure what's the correct math here)

# then using
numpyro.sample("infected_age", dist.ImproperUniform(accept_support, batch_shape=(), event_shape=(4,)))

But in general, I’m not sure what the suitable inference algorithm is for your problem, assuming you have a model p(x, y, z) where x is an observed variable, z is a latent variable, and y is a latent variable that has a rejection sampler but no log density. I think @abie’s kernel might be suitable for your problem, something like

class CustomDistribution(dist.Distribution):
    support = constraints.independent(constraints.interval(0, 0.25), 1)

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__(batch_shape=(), event_shape=(4,))

    def sample(self, key, sample_shape=()):
        # implement the rejection sampler here

# then use this CustomDistribution in your model
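For the uniform case, a minimal sketch of that rejection sampler (my assumptions: a single draw, i.e. sample_shape=(), and uniform proposals on (0, 0.25)^4 re-drawn until the sum clears lower_bound):

def sample(self, key, sample_shape=()):
    assert sample_shape == ()  # this sketch handles a single draw only

    def body_fn(val):
        key, _ = val
        key, subkey = jax.random.split(key)
        # re-propose uniformly on (0, 0.25)^4
        return key, jax.random.uniform(subkey, (4,), minval=0., maxval=0.25)

    def cond_fn(val):
        _, draw = val
        # keep looping while the acceptance condition fails
        return draw.sum() < self.lower_bound

    _, draw = jax.lax.while_loop(cond_fn, body_fn, (key, jnp.zeros(4)))
    return draw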

Thanks @fehiepsi, exactly what I needed to know. I’m comfortable I can code the CustomDistribution constraints and its sampling statement. What else is the bare minimum it needs so that it can be used in a model like any other sampled variable? It’ll be my latent estimate of true infections, with priors and observables hung on it downstream.

The excellent truncated distributions tutorial suggests that log_prob might be an absolute necessity, but possibly not the inverse or Jacobian.

implementing log_prob is the only essential ingredient needed if you want to stick a particular distribution in a model and use hmc to do inference


Excellent.

For my sampler, I intend to draw 4 truncated normals in logit space, bounded below at the proportion of that age group with confirmed cases, and bounded above at that age group’s share of the total population. I’ll then reject if their sum is below the proportion of the total population with known confirmed cases.
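Here is roughly what I have in mind, as a sketch with made-up bounds low and high standing in for the real confirmed-case and population shares (ignoring the logit-space detail for simplicity, and assuming dist.TruncatedNormal with two-sided bounds):

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

# made-up bounds: confirmed-case share (below) and population share
# (above) for each of the 4 age groups
low = jnp.array([0.01, 0.02, 0.02, 0.01])
high = jnp.array([0.20, 0.25, 0.30, 0.15])
county_lower_bound = 0.3  # known share of the county with confirmed cases

def draw_ages(key):
    proposal = dist.TruncatedNormal(0.1, 0.05, low=low, high=high)

    def body_fn(val):
        key, _ = val
        key, subkey = jax.random.split(key)
        return key, proposal.sample(subkey)

    def cond_fn(val):
        _, draw = val
        # reject while the age groups sum to less than the county bound
        return draw.sum() < county_lower_bound

    _, draw = jax.lax.while_loop(cond_fn, body_fn, (key, low))
    return draw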

I’m not 100% sure how to set up the log_prob for that but that’ll be my next step.

i don’t really understand why you’re talking about sampling at all. hmc does not require samplers for primitive distributions that enter into the model. it only needs log_probs and gradients of log_probs. approximate posterior samples are generated by exploring the log density surface using hamiltonian dynamics. what inference algorithm do you want to use exactly?


Just NUTS.

I’m sure this is a terminology problem on my part; all I mean to say is that I need to combine information and priors at multiple scales.

So I have priors and data on the possible number of true COVID infections by age group-day within a county, lower bounded by the number of confirmed cases. I have different priors and data on the possible number of true COVID infections at the county-day level, lower bounded by the number of confirmed cases from a different source. I further have sero-prevalence estimates that are sometimes at the state level and other times for irregular catchments in which counties are nested.

So by sampler I meant the def sample function of a custom distribution, which draws 4 samples, one for each age group, whose sum must be at least the known county level.

I started down the rejection sampling road because I’m more than willing to run the model a little longer if it avoids having to grind out the math for a new custom distribution like a truncated Dirichlet, especially when I may go a different direction in the final model.

it may very well be a terminology thing but i still don’t understand. for a model with continuous latent variable x hmc requires that you can compute the log probability log p(x) (and its gradients) pointwise for every value of x in the support. i don’t see how rejection sampling enters the picture

Problem: how to draw truncated samples whose sum is also truncated.

I asked whether this could be done quick and dirty with rejection, illustrated it with a bad while_loop, and the answer was no.

@fehiepsi then described an AcceptTransform solution that is out of my reach, and a CustomDistribution solution which I might be able to do.

The latter depends on a sample() function that I think I can code with the JAX while loop I first suggested, rejecting unless a criterion is met.

And you’ve explained it also requires a log_prob() function, which I can roughly imagine but would need to grind out next.

That’s where I think I currently am; I may have badly misunderstood.

If you want to use NUTS, then we need to define supports for the “constrained” variables, because the NUTS algorithm works in real domains. Once you have defined the support (the domain that satisfies your acceptance condition, e.g. vectors in the (0, 0.25)^4 domain whose sum is greater than the lower bound), you need to define the corresponding transform so that NUTS can use it to map the constrained values in the support to a real vector and back. When this is available, you can use ImproperUniform to set a non-informative prior for that variable (no need to implement log_prob).

If you can derive the log probability yourself, then you can use such a custom distribution rather than ImproperUniform. In this case, you still need to define the support and the transform as above.
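For what it’s worth, with uniform proposals the rejection-sampled density is flat on the accepted region, so an unnormalized log_prob is just a constant there. A sketch (the normalizing constant, the volume of the polytope, is intractable but drops out of MCMC):

def log_prob(self, value):
    # 0 inside the accepted region, -inf outside; correct only up to the
    # (intractable) log-volume constant, which MCMC does not need
    inside = (constraints.interval(0, 0.25)(value).all(-1)
              & (value.sum(-1) > self.lower_bound))
    return jnp.where(inside, 0.0, -jnp.inf)

Note that the -inf cliff here is essentially the same hard barrier as the factor trick below.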

Maybe you can use the following trick to avoid having to do custom things

# define the variables as usual, no acceptance condition
# then add the following factor
numpyro.factor("reject", jnp.where(accept_cond, 0., -jnp.inf))

NUTS doesn’t like that style, but it might be possible.


I immensely appreciate the help and brainstorming; I’ve been stuck on this problem for a while.

In the spirit of quick algorithmic shortcuts, I tested the numpyro.factor suggestion, and it works but produces divergences even with a big warm-up. I’ve had it beaten into me never to ignore divergences, and this is just a small test out of what’s 3k counties, 700+ days, with a random-walk prior, etc., so experts are going to have to weigh in and tell me whether this is a viable option.

def dirty_factor_example():
    # with numpyro.handlers.seed(rng_seed=1234):
    infected_age1 = numpyro.sample('infected_age1', dist.Uniform(0, 0.25))
    infected_age2 = numpyro.sample('infected_age2', dist.Uniform(0, 0.25))
    infected_age3 = numpyro.sample('infected_age3', dist.Uniform(0, 0.25))
    infected_age4 = numpyro.sample('infected_age4', dist.Uniform(0, 0.25))
    county_total = numpyro.deterministic("county_total", infected_age1 + infected_age2 + infected_age3 + infected_age4)
    accept_cond = county_total > 0.5
    numpyro.factor("reject", jnp.where(accept_cond, 0., -jnp.inf))

nuts_kernel = NUTS(dirty_factor_example)
mcmc = MCMC(nuts_kernel, num_warmup=10000, num_samples=500, num_chains=1)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key)
mcmc.print_summary()

                     mean       std    median      5.0%     95.0%     n_eff     r_hat
  infected_age1      0.15      0.06      0.16      0.06      0.24     85.97      1.01
  infected_age2      0.15      0.06      0.15      0.06      0.25    124.47      1.00
  infected_age3      0.16      0.06      0.18      0.05      0.25     85.37      1.00
  infected_age4      0.16      0.06      0.17      0.06      0.24     85.42      1.00

Number of divergences: 448

np.min(samples['county_total'])
DeviceArray(0.50024515, dtype=float32)

you might have better luck with a soft constraint instead of an infinite barrier but the degree to which you’re ok with that will of course depend on your goals


Quickly trying different penalties for that barrier:
[-inf] - runs quickly, works, many divergences
[-10000000] - runs quickly, still works, many divergences
[-1000, -10] - runs very slowly
[-1] - runs quickly, works, 0 divergences
[-0.1] - runs quickly, works, 0 divergences
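Each penalty here replaces the -jnp.inf barrier in the factor with the listed constant, e.g. for the [-1] case:

numpyro.factor("reject", jnp.where(accept_cond, 0., -1.))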


FYI, a delta energy of 1000 is hard-coded in the HMC implementation. If a NUTS trajectory diverges too much, we report that the trajectory is diverging; this does not mean that the proposed sample at that step is bad (the proposed sample is drawn randomly from the trajectory based on its density, so bad samples are unlikely to be proposed). That explains the many divergences you saw. If you use a small penalty, then there might be some bad samples (those not satisfying the acceptance condition) in your posterior. In that case, I guess you can just throw them away. :) So you can either use large penalties and ignore the divergence diagnostic, or use a small penalty and throw away some bad samples. If you see many bad samples in the posterior, you can increase the penalty to a reasonable value. Just hack the sampler :))
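For concreteness, throwing away the bad draws could look like this (a sketch using the county_total deterministic from the model above, which shows up in mcmc.get_samples()):

samples = mcmc.get_samples()
keep = samples['county_total'] > 0.5
clean = {name: value[keep] for name, value in samples.items()}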

I have a theory that for HMC methods like NUTS, a “penalty” with a non-infinite derivative will work better than the “barrier” approach you have used. Here is an alternative reject factor based on a pattern I recently used to test this theory in some of my own work:

numpyro.factor("reject", jnp.where(accept_cond, 0.,
                                   -10_000**2 * (county_total - 0.5)**2))

I would be interested to know how this goes for your setting.

Also, since we seem to be broadening the discussion beyond rejection sampling, have you considered explicitly representing the underestimation factor for the age-specific rates as an additional parameter in your model?
