# Hierarchical model with selection bias

Hi,

I want to formulate a hierarchical model whose observables are subject to selection bias, and I am not sure what the best formulation of this problem is.

The details (for this toy model) are as follows.

1. There are `n_bin` bins, and for each bin I draw a rate, called `rate_all`, from a uniform distribution.
2. For each bin, I draw the count `n_all` from a Poisson distribution with this rate.
3. For each bin, I then draw `n_all` observations, so the number of observations generally differs from bin to bin. This new variable is called `loudness` and is drawn from a Normal distribution.
4. Only a subset of these `loudness` values is detected, depending on a detection threshold (defined by the function `is_detected`).
5. Define an `obs` statement for the detected `loudness` values.
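For reference, steps 1 to 4 can be written as a plain forward simulation outside of NumPyro. This is only a sketch under assumptions: it uses NumPy rather than JAX so that ragged per-bin arrays are allowed, `simulate` and its arguments are hypothetical names, the `Normal(5, 2)` and the truncation at 50 are taken from the code further down, and the truncation is implemented by simple rejection.

```python
import numpy as np


def is_detected(loudness, threshold=5.0):
    return loudness > threshold


def simulate(n_bin=10, rate_all_max=30.0, high=50, threshold=5.0, seed=0):
    """Forward-simulate the toy model (hypothetical helper, not part of the model code)."""
    rng = np.random.default_rng(seed)

    # Step 1: one rate per bin from a uniform distribution
    rate_all = rng.uniform(0.0, rate_all_max, size=n_bin)

    # Step 2: Poisson counts, redrawn whenever they exceed `high` (rejection)
    n_all = rng.poisson(rate_all)
    while np.any(n_all > high):
        redo = n_all > high
        n_all[redo] = rng.poisson(rate_all[redo])

    # Step 3: a different number of loudness values in each bin
    loudness = [rng.normal(5.0, 2.0, size=n) for n in n_all]

    # Step 4: keep only the values that pass the detection threshold
    detected = [x[is_detected(x, threshold)] for x in loudness]
    return rate_all, n_all, detected


rate_all, n_all, detected = simulate(seed=1)
print([len(d) for d in detected])  # number of detected events per bin
```

The per-bin lists have different lengths, which is harmless in NumPy but is exactly the part that is awkward to express with fixed-size plates.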

How do I implement this most elegantly?

Code:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Meta variables
n_bin = 10
rate_all_max = 30  # Maximum value for rate_all


# Detection function
def is_detected(loudness, threshold=5.0):
    return loudness > threshold


def model(observed_loudness=None, threshold=5.0):
    with numpyro.plate("bins", n_bin):
        # Sample rate_all from a uniform distribution
        rate_all = numpyro.sample("rate_all", dist.Uniform(0, rate_all_max))
        # Draw n_all from a (right-truncated) Poisson distribution with rate rate_all
        n_all = numpyro.sample("n_all", RightTruncatedPoisson(rate_all, high=50))

    # Draw loudness from a normal distribution for each bin
    loudness = []
    for i in range(n_bin):
        with numpyro.plate(f"loudness_{i}_plate", n_all[i]):
            loudness.append(numpyro.sample(f"loudness_{i}", dist.Normal(5, 2)))

    loudness = jnp.concatenate(loudness)

    # Apply the detection function
    detected_mask = is_detected(loudness, threshold)

    # I am not sure how to apply the selection efficiently
    # and how to write the obs statement
```
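One possible way to write the likelihood, sketched under assumptions rather than offered as a definitive answer: if the truncation at `high=50` is ignored, thinning a Poisson count with an independent per-event detection probability `p_det` gives a Poisson count with rate `rate_all * p_det`, and each detected `loudness` follows the Normal truncated below at `threshold`. This marginalizes out `n_all` and the undetected values, so no variable-size plate is needed. The names `selection_log_likelihood` and `bin_index` are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson


def selection_log_likelihood(rate_all, detected_loudness, bin_index, n_bin,
                             loc=5.0, scale=2.0, threshold=5.0):
    """Log-likelihood of the detected events with n_all marginalized out.

    rate_all:           (n_bin,) Poisson rates per bin
    detected_loudness:  (n_det,) detected values, all bins concatenated
    bin_index:          (n_det,) integer bin label of each detected value
    Assumes an untruncated Poisson for n_all (an approximation of the model above).
    """
    # Detection probability: P(loudness > threshold) = Phi(-(threshold - loc) / scale)
    p_det = norm.cdf(-(threshold - loc) / scale)

    # Number of detected events in each bin
    n_det = jnp.bincount(bin_index, length=n_bin)

    # Thinned Poisson for the detected counts
    log_counts = poisson.logpmf(n_det, rate_all * p_det)

    # Normal density renormalized to the detectable region (truncated Normal)
    log_loud = norm.logpdf(detected_loudness, loc, scale) - jnp.log(p_det)

    return log_counts.sum() + log_loud.sum()


ll = selection_log_likelihood(
    rate_all=jnp.array([10.0, 20.0]),
    detected_loudness=jnp.array([6.0, 7.0, 5.5]),
    bin_index=jnp.array([0, 1, 1]),
    n_bin=2,
)
print(ll)
```

Inside `model`, a term like this could in principle be added with `numpyro.factor`. Whether ignoring the truncation is acceptable depends on how much Poisson mass lies above 50 for the rates allowed by the prior.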

with the custom distribution

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats
import numpyro.distributions as dist


class RightTruncatedPoisson(dist.Distribution):
    """
    A right-truncated Poisson distribution.

    :param numpy.ndarray high: upper bound at which truncation happens
    :param numpy.ndarray rate: rate of the Poisson distribution.
    """

    arg_constraints = {
        "high": dist.constraints.nonnegative_integer,
        "rate": dist.constraints.positive,
    }
    has_enumerate_support = True

    def __init__(self, rate=1.0, high=0, validate_args=None):
        batch_shape = jax.lax.broadcast_shapes(jnp.shape(high), jnp.shape(rate))
        self.high, self.rate = dist.util.promote_shapes(high, rate)
        super().__init__(batch_shape, validate_args=validate_args)

    def log_prob(self, value):
        # Renormalize by the Poisson mass on [0, high]
        m = jax.scipy.stats.poisson.cdf(self.high, self.rate)
        log_p = jax.scipy.stats.poisson.logpmf(value, self.rate)
        return jnp.where(value <= self.high, log_p - jnp.log(m), -jnp.inf)

    def sample(self, key, sample_shape=()):
        # Inverse-CDF sampling; the `icdf` implementations are not shown here
        shape = sample_shape + self.batch_shape
        float_type = jnp.result_type(float)
        minval = jnp.finfo(float_type).tiny
        u = jax.random.uniform(key, shape, minval=minval)
        # return self.icdf(u)        # Brute force
        # return self.icdf_faster(u) # For faster sampling.
        return self.icdf(u)  # Using `host_callback`
```
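Since `sample` relies on an inverse CDF that is not included above, here is a minimal sketch of what a brute-force version might look like. It is an assumption for illustration, not the `icdf` actually used: `truncated_poisson_icdf` is a hypothetical standalone function, it only handles a scalar `rate`, and `high` must be a plain Python integer so that the support can be enumerated.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import poisson


def truncated_poisson_icdf(u, rate, high):
    """Brute-force inverse CDF of a Poisson truncated to {0, ..., high}.

    u:    uniform draws in (0, 1), any shape
    rate: scalar Poisson rate
    high: Python int, upper truncation bound (inclusive)
    """
    support = jnp.arange(high + 1)
    log_pmf = poisson.logpmf(support, rate)

    # softmax renormalizes the pmf so that it sums to 1 on the truncated support
    cdf = jnp.cumsum(jax.nn.softmax(log_pmf))

    # Smallest k such that cdf[k] >= u
    k = jnp.searchsorted(cdf, u, side="left")

    # Guard against rounding error pushing the index past `high`
    return jnp.clip(k, 0, high)


u = jax.random.uniform(jax.random.PRNGKey(0), (5,))
print(truncated_poisson_icdf(u, rate=3.0, high=5))
```

Enumerating the whole support costs `O(high)` per call, which is fine for a bound like 50 but is presumably why a faster variant is mentioned in the comments.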