Truncated Binomial, How to implement?

What is the numpyro way of doing this?

Coincidentally, I'm about to write a tutorial on how to construct truncated versions of discrete distributions. For reference, here is how I constructed a TruncatedZeroInflatedPoisson distribution:

def rv_truncated_poisson(mu, mx, size=None):
    """Draw samples from a Poisson(``mu``) truncated above at ``mx`` (inclusive).

    Uses inverse-CDF sampling: a uniform draw is rescaled to the interval
    ``[0, cdf(mx)]`` and mapped through the Poisson quantile function.

    Args:
        mu: Poisson rate; scalar or array-like.
        mx: inclusive upper truncation bound; scalar or array-like.
        size: shape passed to ``np.random.random``; ``None`` draws one uniform
            value that is broadcast against ``mu`` and ``mx``.

    Returns:
        Float-valued samples (the quantile function returns floats) that are
        never below 0.
    """
    mu = np.asarray(mu)
    mx = np.asarray(mx)
    # Named ``poisson`` rather than ``dist`` to avoid shadowing a module alias.
    poisson = stats.distributions.poisson(mu)

    lower_cdf = 0.
    upper_cdf = poisson.cdf(mx)
    nrm = upper_cdf - lower_cdf
    sample = np.random.random(size) * nrm + lower_cdf

    # SciPy's discrete ``ppf`` returns (support minimum - 1), i.e. -1, for a
    # quantile of exactly 0. That happens whenever ``cdf(mx)`` underflows to 0
    # (very large ``mu`` relative to ``mx``) or the uniform draw is exactly 0.
    # Clamp so the result always stays inside the support.
    return np.maximum(poisson.ppf(sample), 0.)


def rv_truncated_zip(args):
    """Sample a zero-inflated, upper-truncated Poisson on the host.

    ``args`` is a single tuple ``(rate, gate, high, shape)`` so the function
    can be invoked with one argument. A truncated-Poisson draw is kept when
    an independent uniform draw exceeds ``gate`` and replaced by zero
    otherwise.
    """
    rate, gate, high, shape = args
    counts = rv_truncated_poisson(rate, high, size=shape)
    # Boolean mask: True where the Poisson draw survives zero inflation.
    keep = np.random.random(shape) > gate
    return counts * keep


class TruncatedZeroInflatedPoisson(dist.Distribution):
    """Zero-inflated Poisson whose Poisson component is truncated at ``high``.

    With probability ``gate`` the outcome is an extra zero; otherwise it
    follows a Poisson(``rate``) renormalised over values up to and including
    ``high``.

    Args:
        rate: rate of the Poisson component.
        gate: probability of the extra zero.
        high: inclusive upper truncation bound of the Poisson component.
        validate_args: forwarded unchanged to ``dist.Distribution.__init__``.
    """

    def __init__(self, rate, gate, high, validate_args=None):
        self.rate, self.gate, self.high = rate, gate, high
        batch_shape = jax.lax.broadcast_shapes(
            jnp.shape(rate), jnp.shape(gate), jnp.shape(high))
        # Forward ``validate_args`` so the caller's setting is honoured
        # instead of being silently replaced by None.
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` as integers.

        The draw is delegated to ``rv_truncated_zip`` through a host callback.
        ``key`` is not used by this method, so results are not controlled by
        the JAX PRNG key passed in.
        """
        # NOTE(review): ``jax.experimental.host_callback`` appears to be
        # deprecated in newer JAX releases; ``jax.pure_callback`` is the
        # likely replacement -- confirm against the JAX version in use.
        shape = sample_shape + self.batch_shape
        samples = jax.experimental.host_callback.call(
            rv_truncated_zip, (self.rate, self.gate, self.high, shape),
            result_shape=jax.ShapeDtypeStruct(shape, jnp.result_type(float)))
        return samples.astype(jnp.result_type(int))

    def log_prob(self, value):
        """Return the log-probability of ``value`` under the mixture."""
        # Poisson CDF at ``high``: P(X <= high) = Q(high + 1, rate), the
        # regularised upper incomplete gamma function.
        upper_cdf = jax.scipy.special.gammaincc(self.high + 1, self.rate)
        # Truncated-Poisson log-pmf: renormalise by the mass kept below the
        # truncation bound.
        log_prob = dist.Poisson(self.rate).log_prob(value) - jnp.log(upper_cdf)
        # Weight by the probability of not taking the extra zero.
        log_prob = jnp.log1p(-self.gate) + log_prob
        # At zero, the extra-zero mass ``gate`` is added to the Poisson mass.
        return jnp.where(value == 0, jnp.log(self.gate + jnp.exp(log_prob)), log_prob)

For the binomial case, you can use betainc instead of gammaincc to compute the CDF at the truncation bounds. I used host_callback in the sample method because JAX does not yet provide functions to compute the inverse CDF (i.e. ppf) of the Poisson or Binomial distributions. I believe you can do the same for a truncated Binomial distribution.