Hi,
I wanted to formulate a hierarchical model that also produces observables that are selection biased. I am not sure what the best formulation to this problem is.
The details (for this toy model) are as follows.
- There are
n_bins
bins and for each I draw a rate, calledrate_all
(from a uniform distribution). - Poisson sample the variable
n_all
given this rate. - For each of the bins, then draw
n_all
observations for each bin (so the number of observations in general is different in each bin). This new variable is calledloudness
and is drawn from a Normal distribution. - Depending on some kind of detected threshold (defined by the function
is_detected
), only a subset of theseloudness
s are detected. - Define an obs statement for the
loudness
parameter.
How do I implement this most elegantly?
Code:
# Meta variable
n_bin = 10
rate_all_max = 30 # Define the maximum value for rate_all
# Detection function
def is_detected(loudness, threshold=5.0):
return loudness > threshold
def model(observed_loudness=None, threshold=5.0):
# Draw n_all from a Poisson distribution with rate_all
with numpyro.plate("bins", n_bin):
# Sample rate_all from a uniform distribution
rate_all = numpyro.sample("rate_all", dist.Uniform(0, rate_all_max))
n_all = numpyro.sample("n_all", RightTruncatedPoisson(rate_all, high=50))
# Draw loudness from a normal distribution for each bin
loudness = []
for i in range(n_bin):
with numpyro.plate(f"loudness_{i}_plate", n_all[i]):
loudness.append(numpyro.sample(f"loudness_{i}", dist.Normal(5, 2)))
loudness = jnp.concatenate(loudness)
# Apply the detection function
detected_mask = is_detected(loudness, threshold)
# I am not sure how I efficiently apply the selection
# and write the obs statement
with the custom distribution
class RightTruncatedPoisson(dist.Distribution):
"""
A truncated Poisson distribution.
:param numpy.ndarray high: high bound at which truncation happens
:param numpy.ndarray rate: rate of the Poisson distribution.
"""
arg_constraints = {
"high": dist.constraints.nonnegative_integer,
"rate": dist.constraints.positive,
}
has_enumerate_support = True
def __init__(self, rate=1.0, high=0, validate_args=None):
batch_shape = jax.lax.broadcast_shapes(jnp.shape(high), jnp.shape(rate))
self.high, self.rate = dist.util.promote_shapes(high, rate)
super().__init__(batch_shape, validate_args=validate_args)
def log_prob(self, value):
m = jax.scipy.stats.poisson.cdf(self.high, self.rate)
log_p = jax.scipy.stats.poisson.logpmf(value, self.rate)
return jnp.where(value <= self.high, log_p - jnp.log(m), -jnp.inf)
def sample(self, key, sample_shape=()):
shape = sample_shape + self.batch_shape
float_type = jnp.result_type(float)
minval = jnp.finfo(float_type).tiny
u = jax.random.uniform(key, shape, minval=minval)
# return self.icdf(u) # Brute force
# return self.icdf_faster(u) # For faster sampling.
return self.icdf(u) # Using `host_callback`