Truncated Binomial, How to implement?

Hello, I am trying to make a numpyro implementation of the following WinBUGS code for a censored data:

# ChaSaSoon Censored Data
  for (i in 1:nattempts){
    # If the Data Were Unobserved y[i]=1, Otherwise y[i]=0   
    z.low[i]  <- 15*equals(y[i],1)+0*equals(y[i],0)
    z.high[i] <- 25*equals(y[i],1)+n*equals(y[i],0)
    z[i] ~ dbin(theta,n)I(z.low[i],z.high[i])
  # Uniform Prior on Rate Theta
  theta ~ dbeta(1,1)I(.25,1)

My difficulty is in translating the code

    z[i] ~ dbin(theta, n) I(z.low[i], z.high[i])

which means that z[i] is a binomial sample whose range is constrained to be in the interval [z.low[i], z.high[i]].

Specifically, the sample must be in [15, 25] from Binomial(theta, n=50).

Q1. Is there a way to use constraints.Integer_interval? I tried to find an example but in vain.

Q2. What is the numpyro way of doing this?

Many thanks for advice.

  • The code is from
  • numpyro’s TruncatedDistribution does not apply because it is for continuous variables only.

What is the numpyro way of doing this?

It is coincident that I’m going to write a tutorial on how to construct a truncated distribution for discrete distributions. Just for your reference, here is how I constructed TruncatedZeroInflatedPoisson distribution

def rv_truncated_poisson(mu, mx, size=None):
    mu = np.asarray(mu)
    mx = np.asarray(mx)
    dist = stats.distributions.poisson(mu)

    lower_cdf = 0.
    upper_cdf = dist.cdf(mx)
    nrm = upper_cdf - lower_cdf
    sample = np.random.random(size) * nrm + lower_cdf

    return dist.ppf(sample)

def rv_truncated_zip(args):
    rate, gate, high, shape = args
    g = rv_truncated_poisson(rate, high, size=shape)
    return g * (np.random.random(shape) > gate)

class TruncatedZeroInflatedPoisson(dist.Distribution):

    def __init__(self, rate, gate, high, validate_args=None):
        self.rate, self.gate, self.high = rate, gate, high
        batch_shape = jax.lax.broadcast_shapes(
            jnp.shape(rate), jnp.shape(gate), jnp.shape(high))
        super().__init__(batch_shape, validate_args=None)

    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        samples =
            rv_truncated_zip, (self.rate, self.gate, self.high, shape),
            result_shape=jax.ShapeDtypeStruct(shape, jnp.result_type(float)))
        return samples.astype(jnp.result_type(int))

    def log_prob(self, value):
        upper_cdf = jax.scipy.special.gammaincc(self.high + 1, self.rate)
        log_prob = dist.Poisson(self.rate).log_prob(value) - jnp.log(upper_cdf)
        log_prob = jnp.log1p(-self.gate) + log_prob
        return jnp.where(value == 0, jnp.log(self.gate + jnp.exp(log_prob)), log_prob)

For binomial, instead of gammaincc, you can use betainc to compute the cdf at truncated bounds. I used host_callback in sample method because jax does not have functions to compute inverse cdf (i.e. ppf) of Poisson/Binomial yet. I believe you can do the same for truncated Binomial distribution.

Thanks a lot @fehiepsi

I think I made a class and its usage seems OK.

  • betainc was really useful. I would have spent a lot of time to find it.
  • allowed me only one function argument. A tuple type argument produced errors mentioning infeed buffersize. I simply solved it by using only one.

But when I used it in a model for MCMC, the produces ‘NotImplementedError’ with a long listing of related codes.

Probably, I think it would be better to ask you whether or not some more class methods must be implemented in addition to those three, __init__(), sample(), and log_prob() in order for the class to be used for mcmc inference.


from re import S
import numpyro 
import numpyro.distributions as dist 
from numpyro.infer import MCMC, NUTS, Predictive
import jax 
import jax.random as random 
import jax.numpy as jnp 

import numpy as np
import scipy
import scipy.stats as stats
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns; 
import arviz

def rv_truncated_binomial (nparr):
    # print('nparr: ', nparr)
    n, p, low, high = nparr[:4]
    shape = tuple(int(i) for i in nparr[4:])
    # print('args: ', n, p, low, high, shape)

    dist = scipy.stats.binom(n, p)
    plow = dist.cdf(low-1)
    phigh = dist.cdf(high)
    # print(f'plow: {plow} phigh: {phigh}')
    r = np.random.random(size=shape)  #  r \in  [0, 1)
    mr = r * (phigh - plow) + plow  # now,  low <= ppf(r) <= high
    # print(f'r: {r}, mr: {mr}, ppf: {dist.ppf(mr)}')
    return dist.ppf(mr).astype(np.float32)

class TruncatedBinomial(dist.Distribution):
    def __init__(self, total_count, probs, low=0, high=None, valid_args=None):
        self.total_count, self.probs, self.low, self.high, self.valid_args = total_count, probs, low, high, valid_args
        self.param = jnp.array([total_count, probs, low, high])
        if high is None:
            high = total_count  # Use dist.Binomial for normal Binomial

        # normalization constant for the truncated binomial
        self.Z = jax.scipy.special.betainc( a = total_count - high, 
                                            b = 1 + high, 
                                            x = probs) \
                - jax.scipy.special.betainc(a = total_count - (low - 1),
                                            b = 1 + (low - 1),
                                            x = probs)
    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        if shape == ():
            param = self.param
            param = jnp.concatenate((self.param, jnp.array(shape)))
        # print('param: ', param)
        samples =
                                        rv_truncated_binomial, param,
                                        result_shape=jax.ShapeDtypeStruct(shape, jnp.float32))
        return samples.astype(jnp.int32)
    def log_prob(self, value):
            low <= value <= high
        log_prob = dist.Binomial(self.total_count, probs=self.probs).log_prob(value) - self.Z
        return jnp.where(self.low <= value <= self.high, log_prob, 1e-10)  # jnp.inf ?

if __name__ == "__main__":
    rng_key = random.PRNGKey(0)
    rng_key, sub = jax.random.split(rng_key)

    lhs = [(0, 10), (3, 7), (4, 5)]

    for low, high in lhs:
        print(f'* low: {low}, high: {high}')
        tb = TruncatedBinomial(10, .5, low=low, high=high)
        s = tb.sample(rng_key, sample_shape=(100000,))

        u, c = np.unique(s, return_counts=True)

Are you using TruncatedBinomial as a latent variable or an observed variable? Could you make some reproducible code for the error? FYI, I used TruncatedZeroInflatedPoisson as an observed site in MCMC without problems.