What is the numpyro way of doing this?
Coincidentally, I'm about to write a tutorial on how to construct truncated versions of discrete distributions. For reference, here is how I constructed a `TruncatedZeroInflatedPoisson` distribution:
def rv_truncated_poisson(mu, mx, size=None):
    """Draw Poisson(mu) samples truncated to the support {0, ..., mx}.

    Uses inverse-CDF sampling: a uniform draw on [0, F(mx)) is mapped through
    the Poisson ppf, where F is the Poisson CDF. Randomness comes from the
    global ``np.random`` state.

    :param mu: Poisson rate(s); broadcast against ``mx`` and ``size``.
    :param mx: inclusive upper truncation bound(s).
    :param size: output shape, forwarded to ``np.random.random``.
    :return: float array of integer-valued samples in [0, mx].
    """
    mu = np.asarray(mu)
    mx = np.asarray(mx)
    # Named ``poisson`` (not ``dist``) so the module-level ``dist`` alias used
    # by the distribution class is not shadowed.
    poisson = stats.poisson(mu)
    lower_cdf = 0.0
    upper_cdf = poisson.cdf(mx)
    nrm = upper_cdf - lower_cdf
    sample = np.random.random(size) * nrm + lower_cdf
    # scipy's discrete ppf returns -1 at q == 0, and np.random.random can
    # return exactly 0.0, so clip the result back into the truncated support.
    return np.clip(poisson.ppf(sample), 0, mx)
def rv_truncated_zip(args):
    """Draw truncated zero-inflated Poisson samples with NumPy.

    Written to be passed to a host callback, so every parameter arrives packed
    into a single tuple.

    :param args: ``(rate, gate, high, shape)`` -- Poisson rate, probability of a
        structural zero, inclusive upper truncation bound, and output shape.
    :return: array of shape ``shape``; each entry is a truncated Poisson draw,
        zeroed out whenever its companion uniform draw does not exceed ``gate``.
    """
    rate, gate, high, shape = args
    counts = rv_truncated_poisson(rate, high, size=shape)
    keep_mask = np.random.random(shape) > gate
    return counts * keep_mask
class TruncatedZeroInflatedPoisson(dist.Distribution):
    """Zero-inflated Poisson distribution truncated to the support {0, ..., high}.

    With probability ``gate`` the value is a structural zero; otherwise it is
    drawn from ``Poisson(rate)`` conditioned on being at most ``high``.

    :param rate: Poisson rate (must be positive).
    :param gate: probability of a structural zero, in [0, 1].
    :param high: inclusive, non-negative integer upper bound of the support.
    :param validate_args: whether to validate parameters and sample values.
    """

    arg_constraints = {
        "rate": dist.constraints.positive,
        "gate": dist.constraints.unit_interval,
        "high": dist.constraints.nonnegative_integer,
    }

    def __init__(self, rate, gate, high, validate_args=None):
        self.rate, self.gate, self.high = rate, gate, high
        batch_shape = jax.lax.broadcast_shapes(
            jnp.shape(rate), jnp.shape(gate), jnp.shape(high)
        )
        # Forward the caller's validate_args (it used to be hard-coded to None).
        super().__init__(batch_shape, validate_args=validate_args)

    @dist.constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        return dist.constraints.integer_interval(0, self.high)

    def _truncation_mass(self):
        """Return P(N <= high) for N ~ Poisson(rate).

        The Poisson CDF at k equals the regularized upper incomplete gamma
        function Q(k + 1, rate).
        """
        return jax.scipy.special.gammaincc(self.high + 1, self.rate)

    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        key_count, key_gate = jax.random.split(key)
        # Inverse-CDF sampling restricted to [0, high]: q ~ Uniform[0, F(high)).
        # Randomness comes from `key`, so samples are reproducible under JAX PRNG.
        q = jax.random.uniform(key_count, shape) * self._truncation_mass()
        out_dtype = jnp.result_type(float)

        def _poisson_ppf(q, rate):
            # JAX has no Poisson inverse CDF, so evaluate scipy's on the host.
            # Cast to the declared result dtype expected by the callback.
            return np.asarray(stats.poisson.ppf(q, rate), dtype=out_dtype)

        counts = jax.pure_callback(
            _poisson_ppf,
            jax.ShapeDtypeStruct(shape, out_dtype),
            q,
            jnp.asarray(self.rate, dtype=out_dtype),
        )
        # scipy's ppf returns -1 at q == 0, and float32 (JAX) vs float64 (scipy)
        # CDF rounding can push it to high + 1; clip back into the support.
        counts = jnp.clip(counts, 0, self.high)
        structural_zero = jax.random.bernoulli(key_gate, self.gate, shape)
        samples = jnp.where(structural_zero, 0, counts)
        return samples.astype(jnp.result_type(int))

    def log_prob(self, value):
        # Poisson log-mass renormalized by the mass kept after truncation.
        log_count = dist.Poisson(self.rate).log_prob(value) - jnp.log(
            self._truncation_mass()
        )
        log_nonzero = jnp.log1p(-self.gate) + log_count
        # A zero can come from the gate or from the truncated Poisson component.
        log_p = jnp.where(
            value == 0, jnp.log(self.gate + jnp.exp(log_nonzero)), log_nonzero
        )
        # Values outside {0, ..., high} have zero probability under truncation.
        in_support = (value >= 0) & (value <= self.high)
        return jnp.where(in_support, log_p, -jnp.inf)
For a binomial, you can use `betainc` instead of `gammaincc` to compute the CDF at the truncation bounds (the binomial CDF at k is `betainc(n - k, k + 1, 1 - p)`). I used a host callback in the `sample` method because JAX does not yet have functions to compute the inverse CDF (i.e., the ppf) of the Poisson or binomial distributions. I believe the same approach works for a truncated binomial distribution.