Hierarchical model with selection bias

Hi,

I want to formulate a hierarchical model whose observables are subject to selection bias, but I am not sure what the best formulation of this problem is.

The details (for this toy model) are as follows.

  1. There are n_bin bins, and for each bin I draw a rate, called rate_all, from a uniform distribution.
  2. Given this rate, draw the count n_all from a Poisson distribution (in the code below, a right-truncated Poisson).
  3. Then, for each bin, draw n_all observations (so the number of observations generally differs between bins). This new variable is called loudness and is drawn from a Normal distribution.
  4. Depending on a detection threshold (implemented by the function is_detected), only a subset of these loudness values is detected.
  5. Condition the model on the detected loudness values with an obs statement.

How do I implement this most elegantly?

Code:

# Global settings read by the model.
n_bin: int = 10  # number of bins
rate_all_max: int = 30  # upper bound of the Uniform prior on each bin's rate_all

def is_detected(loudness, threshold=5.0):
    """Flag which loudness values are detected.

    A value counts as detected when it is strictly greater than ``threshold``.
    The comparison is elementwise when ``loudness`` is a NumPy/JAX array.

    :param loudness: scalar or array of loudness values.
    :param float threshold: detection threshold (exclusive).
    :return: bool, or boolean array, equal to ``loudness > threshold``.
    """
    above_threshold = loudness > threshold
    return above_threshold

def model(observed_loudness=None, threshold=5.0):
    """Draft generative model for per-bin counts and selection-biased loudness.

    For each of ``n_bin`` bins:

    - draw ``rate_all ~ Uniform(0, rate_all_max)``;
    - draw ``n_all ~ RightTruncatedPoisson(rate_all, high=50)``;
    - draw ``n_all[i]`` loudness values ``~ Normal(5, 2)``.

    The values are then concatenated and flagged as detected via ``is_detected``.

    :param observed_loudness: intended observed (detected) loudness values.
        It is currently unused because no ``obs=`` statement has been written yet.
    :param float threshold: detection threshold forwarded to ``is_detected``.
    :return: None (the draft has no return statement).
    """
    
    # Per-bin rate and (truncated) Poisson count of sources.
    with numpyro.plate("bins", n_bin):
        # Sample rate_all from a uniform distribution
        rate_all = numpyro.sample("rate_all", dist.Uniform(0, rate_all_max))
        # NOTE(review): the truncation bound 50 is hard-coded and independent of
        # rate_all_max. n_all is a discrete latent, so gradient-based samplers
        # such as NUTS would presumably need it enumerated or marginalised; confirm.
        n_all = numpyro.sample("n_all", RightTruncatedPoisson(rate_all, high=50))
    
    # Draw loudness from a normal distribution for each bin
    loudness = []
    for i in range(n_bin):
        # NOTE(review): the plate size n_all[i] is a sampled value. NumPyro plate
        # sizes are expected to be static Python ints, so this likely fails under
        # jit/MCMC, where n_all is traced; confirm. A common workaround is a fixed
        # maximum size (e.g. the truncation bound) combined with a mask.
        with numpyro.plate(f"loudness_{i}_plate", n_all[i]):
            loudness.append(numpyro.sample(f"loudness_{i}", dist.Normal(5, 2)))
    
    # Concatenate the per-bin samples into a single array.
    loudness = jnp.concatenate(loudness)
    
    # Apply the detection function: True where loudness exceeds `threshold`.
    detected_mask = is_detected(loudness, threshold)
    
    # I am not sure how I efficiently apply the selection
    # and write the obs statement
    # TODO: detected_mask and observed_loudness are not used yet, so the model
    # currently has no likelihood term for the observed data.

 

The model uses the following custom distribution:

class RightTruncatedPoisson(dist.Distribution):
    """
    A Poisson distribution truncated from above at ``high``.

    ``log_prob`` renormalises the Poisson pmf by ``P(X <= high)`` and returns
    ``-inf`` for values greater than ``high`` (the bound is inclusive).

    :param numpy.ndarray rate: rate of the underlying Poisson distribution.
    :param numpy.ndarray high: inclusive upper bound at which truncation happens.
    """

    # Constraints on the constructor parameters.
    arg_constraints = {
        "high": dist.constraints.nonnegative_integer,
        "rate": dist.constraints.positive,
    }
    # NOTE(review): this flag advertises an ``enumerate_support`` method, which is
    # not visible in this excerpt; confirm it is implemented.
    has_enumerate_support = True

    def __init__(self, rate=1.0, high=0, validate_args=None):
        """Store ``rate`` and ``high`` and set the batch shape to their broadcast shape."""
        # The batch shape is the broadcast of the two parameter shapes.
        shape_of_batch = jax.lax.broadcast_shapes(jnp.shape(high), jnp.shape(rate))
        promoted = dist.util.promote_shapes(high, rate)
        self.high = promoted[0]
        self.rate = promoted[1]
        super().__init__(shape_of_batch, validate_args=validate_args)

    def log_prob(self, value):
        """Log-probability of ``value`` under the right-truncated Poisson.

        Values ``<= high`` get the Poisson log-pmf minus ``log P(X <= high)``.
        Values above ``high`` get ``-inf``.
        """
        # Log of the Poisson mass retained by the truncation.
        log_norm = jnp.log(jax.scipy.stats.poisson.cdf(self.high, self.rate))
        in_support = value <= self.high
        truncated = jax.scipy.stats.poisson.logpmf(value, self.rate) - log_norm
        return jnp.where(in_support, truncated, -jnp.inf)

    def sample(self, key, sample_shape=()):
        """Draw samples by applying ``self.icdf`` to uniform variates.

        :param key: JAX PRNG key.
        :param tuple sample_shape: extra leading sample dimensions.
        :return: the result of ``self.icdf`` on uniforms of shape
            ``sample_shape + batch_shape``.
        """
        shape = sample_shape + self.batch_shape
        float_type = jnp.result_type(float)
        # Use the smallest positive float as the lower bound so that u is never
        # exactly 0, presumably to keep the inverse CDF away from that endpoint.
        minval = jnp.finfo(float_type).tiny
        # Draw in the same dtype whose `tiny` was used as the lower bound.
        u = jax.random.uniform(key, shape, dtype=float_type, minval=minval)
        # Inverse-CDF sampling. NOTE(review): `icdf` is not defined in the visible
        # part of this class. An earlier comment called this call "Brute force" and
        # also claimed it used `host_callback`; confirm which implementation it is.
        return self.icdf(u)