I want to formulate a hierarchical model whose observables are selection biased, and I am not sure what the best formulation of this problem is.
The details (for this toy model) are as follows.
- There are `n_bin` bins, and for each bin I draw a rate, called `rate_all` (from a uniform distribution).
- Poisson sample the variable `n_all` given this rate.
- For each of the bins, then draw `n_all` observations (so the number of observations is in general different in each bin). This new variable is called `loudness` and is drawn from a Normal distribution.
- Depending on some kind of detection threshold (defined by the function `is_detected`), only a subset of these `loudness` values are detected.
- Define an obs statement for the detected `loudness` values (`observed_loudness`).
How do I implement this most elegantly?
# Meta variables
# Number of bins (size of the "bins" plate and the loop bound in `model`).
n_bin = 10
# Upper bound of the Uniform(0, rate_all_max) prior on `rate_all`.
rate_all_max = 30
# Detection function
def is_detected(loudness, threshold=5.0):
    """Return whether ``loudness`` is strictly above the detection threshold.

    The comparison is applied with ``>``, so it works element-wise on arrays
    and returns a boolean mask of the same shape in that case.

    :param loudness: scalar or array of loudness values.
    :param threshold: detection threshold; values equal to it are NOT detected.
    :return: ``loudness > threshold``.
    """
    return loudness > threshold
def model(observed_loudness=None, threshold=5.0):
    """Toy hierarchical model: per-bin rates, counts and loudness draws.

    For each of the ``n_bin`` bins a rate is drawn from
    ``Uniform(0, rate_all_max)`` and a count ``n_all`` from
    ``RightTruncatedPoisson(rate_all, high=50)``. Each bin ``i`` then draws
    ``n_all[i]`` loudness values from ``Normal(5, 2)``, and ``is_detected``
    marks which of them exceed ``threshold``.

    :param observed_loudness: currently unused; no ``obs`` site is defined
        yet (see the TODO at the end of the body).
    :param threshold: detection threshold forwarded to ``is_detected``.
    """
    with numpyro.plate("bins", n_bin):
        # Sample rate_all from a uniform distribution
        rate_all = numpyro.sample("rate_all", dist.Uniform(0, rate_all_max))
        # Draw n_all from a (right-truncated) Poisson distribution with rate_all
        n_all = numpyro.sample("n_all", RightTruncatedPoisson(rate_all, high=50))

    # Draw loudness from a normal distribution for each bin
    loudness = []
    for i in range(n_bin):
        # NOTE(review): the plate size is a sampled value, so it is only known
        # at run time; presumably this breaks under jit / MCMC tracing, where
        # shapes must be static -- confirm before using this for inference.
        with numpyro.plate(f"loudness_{i}_plate", n_all[i]):
            loudness.append(numpyro.sample(f"loudness_{i}", dist.Normal(5, 2)))
    loudness = jnp.concatenate(loudness)

    # Apply the detection function
    detected_mask = is_detected(loudness, threshold)

    # TODO: I am not sure how I efficiently apply the selection
    # and write the obs statement (detected_mask and observed_loudness are
    # not used yet).
The model above uses the following custom distribution:
class RightTruncatedPoisson(dist.Distribution):
    """
    A truncated Poisson distribution.

    :param numpy.ndarray high: high bound at which truncation happens
    :param numpy.ndarray rate: rate of the Poisson distribution.
    """

    arg_constraints = {
        "high": dist.constraints.nonnegative_integer,
        "rate": dist.constraints.positive,
    }
    has_enumerate_support = True

    def __init__(self, rate=1.0, high=0, validate_args=None):
        """Broadcast ``rate`` and ``high`` to a common batch shape and store them."""
        batch_shape = jax.lax.broadcast_shapes(jnp.shape(high), jnp.shape(rate))
        self.high, self.rate = dist.util.promote_shapes(high, rate)
        super().__init__(batch_shape, validate_args=validate_args)

    def log_prob(self, value):
        """Return the log-probability of ``value`` under the truncated Poisson.

        The Poisson log-pmf is renormalised by the Poisson CDF evaluated at
        ``high``; values above ``high`` get ``-inf``.
        """
        # Probability mass the untruncated Poisson puts on {0, ..., high}.
        m = jax.scipy.stats.poisson.cdf(self.high, self.rate)
        log_p = jax.scipy.stats.poisson.logpmf(value, self.rate)
        return jnp.where(value <= self.high, log_p - jnp.log(m), -jnp.inf)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` via ``self.icdf``.

        :param key: JAX PRNG key passed to ``jax.random.uniform``.
        :param sample_shape: leading shape prepended to ``batch_shape``.
        """
        shape = sample_shape + self.batch_shape
        float_type = jnp.result_type(float)
        # Smallest positive normal float, used as the lower bound so that the
        # uniform draw is never exactly 0.
        minval = jnp.finfo(float_type).tiny
        u = jax.random.uniform(key, shape, minval=minval)
        # NOTE(review): `icdf` is not defined in the visible part of this
        # class -- confirm it is provided further down. The original notes
        # label `self.icdf(u)` as the brute-force variant and mention
        # `self.icdf_faster(u)` as a faster alternative.
        return self.icdf(u)