# Truncated Binomial: How to Implement It?

Hello, I am trying to write a numpyro implementation of the following WinBUGS code for censored data:

# ChaSaSoon Censored Data
model{
for (i in 1:nattempts){
# If the Data Were Unobserved y[i]=1, Otherwise y[i]=0
# Truncation bounds: for censored attempts (y[i]=1) z[i] is only known to lie
# in [15, 25]; for observed attempts (y[i]=0) the bounds are the full range [0, n].
z.low[i]  <- 15*equals(y[i],1)+0*equals(y[i],0)
z.high[i] <- 25*equals(y[i],1)+n*equals(y[i],0)
# Binomial(theta, n) likelihood truncated to [z.low[i], z.high[i]]
z[i] ~ dbin(theta,n)I(z.low[i],z.high[i])
}
# Uniform Prior on Rate Theta
# (Beta(1,1) = Uniform(0,1), restricted to [0.25, 1])
theta ~ dbeta(1,1)I(.25,1)
}


My difficulty is in translating the line

    z[i] ~ dbin(theta, n) I(z.low[i], z.high[i])


which means that z[i] is a binomial sample whose range is constrained to be in the interval [z.low[i], z.high[i]].

Specifically, for a censored attempt the sample from Binomial(theta, n=50) must lie in [15, 25].

Q1. Is there a way to use constraints.integer_interval? I tried to find an example, but in vain.

Q2. What is the numpyro way of doing this?

• The code is from https://bayesmodels.com/.
• numpyro’s TruncatedDistribution does not apply because it is for continuous variables only.

What is the numpyro way of doing this?

Coincidentally, I’m going to write a tutorial on how to construct truncated versions of discrete distributions. For your reference, here is how I constructed a TruncatedZeroInflatedPoisson distribution:

def rv_truncated_poisson(mu, mx, size=None):
    """Draw Poisson(mu) samples truncated to [0, mx] by inverse-CDF sampling.

    Runs on the host with NumPy/SciPy and draws from NumPy's global RNG, so
    it is meant to be called through a JAX host callback.

    Args:
        mu: Poisson rate(s).
        mx: inclusive upper truncation bound(s).
        size: output shape passed to ``np.random.random``.

    Returns:
        Float sample(s) (as returned by ``ppf``) with values in ``[0, mx]``.
    """
    mu = np.asarray(mu)
    mx = np.asarray(mx)
    poisson = stats.distributions.poisson(mu)

    lower_cdf = 0.
    upper_cdf = poisson.cdf(mx)
    nrm = upper_cdf - lower_cdf
    # A discrete ppf maps q to the smallest k with cdf(k) >= q, so q must lie
    # in the half-open interval (lower_cdf, upper_cdf]. np.random.random
    # returns [0, 1), and scipy's ppf(0) is -1 (outside the support), so we
    # flip it to (0, 1].
    u = 1.0 - np.random.random(size)
    sample = u * nrm + lower_cdf

    return poisson.ppf(sample)

def rv_truncated_zip(args):
    """Sample a zero-inflated, [0, high]-truncated Poisson on the host.

    Args:
        args: tuple ``(rate, gate, high, shape)``; ``gate`` is the probability
            of forcing a zero and ``shape`` is the output shape.

    Returns:
        Truncated-Poisson draws, each zeroed out with probability ``gate``.
    """
    rate, gate, high, shape = args
    counts = rv_truncated_poisson(rate, high, size=shape)
    # Keep a draw only where a fresh uniform exceeds the gate probability.
    keep_mask = np.random.random(shape) > gate
    return counts * keep_mask

class TruncatedZeroInflatedPoisson(dist.Distribution):
    """Zero-inflated Poisson distribution truncated to the range [0, high].

    With probability ``gate`` the value is a structural zero; otherwise it is
    drawn from Poisson(rate) truncated to ``[0, high]``.

    Args:
        rate: Poisson rate.
        gate: probability of a structural zero.
        high: inclusive upper truncation bound.
        validate_args: forwarded to ``numpyro.distributions.Distribution``.
    """

    def __init__(self, rate, gate, high, validate_args=None):
        self.rate, self.gate, self.high = rate, gate, high
        # Batch shape is the broadcast of all parameter shapes.
        batch_shape = jax.lax.broadcast_shapes(
            jnp.shape(rate), jnp.shape(gate), jnp.shape(high))
        super().__init__(batch_shape, validate_args=validate_args)

    @dist.constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        # Integer values in [0, high].
        return dist.constraints.integer_interval(0, self.high)

    def sample(self, key, sample_shape=()):
        """Draw samples on the host via ``rv_truncated_zip``.

        NOTE: ``key`` is not used; the host function draws from NumPy's global
        RNG, so results are not controlled by the JAX PRNG key.
        """
        shape = sample_shape + self.batch_shape
        samples = jax.experimental.host_callback.call(
            rv_truncated_zip, (self.rate, self.gate, self.high, shape),
            result_shape=jax.ShapeDtypeStruct(shape, jnp.result_type(float)))
        return samples.astype(jnp.result_type(int))

    def log_prob(self, value):
        """Log-probability of ``value``; ``-inf`` outside ``[0, high]``."""
        # P(X <= high) for X ~ Poisson(rate) equals the regularized upper
        # incomplete gamma Q(high + 1, rate).
        upper_cdf = jax.scipy.special.gammaincc(self.high + 1, self.rate)
        log_prob = dist.Poisson(self.rate).log_prob(value) - jnp.log(upper_cdf)
        log_prob = jnp.log1p(-self.gate) + log_prob
        # A zero can come from the gate or from the truncated Poisson itself.
        log_prob = jnp.where(value == 0, jnp.log(self.gate + jnp.exp(log_prob)), log_prob)
        in_support = (value >= 0) & (value <= self.high)
        return jnp.where(in_support, log_prob, -jnp.inf)


For the binomial, instead of gammaincc, you can use betainc to compute the CDF at the truncation bounds: for X ~ Binomial(n, p) and 0 <= k < n, P(X <= k) = betainc(n - k, k + 1, 1 - p). I used host_callback in the sample method because JAX does not yet have functions to compute the inverse CDF (i.e. ppf) of Poisson/Binomial. I believe you can do the same for a truncated Binomial distribution.

Thanks a lot @fehiepsi

I think I have made the class, and its usage seems OK.

• betainc was really useful. I would have spent a lot of time finding it on my own.
• jax.experimental.host_callback.call() allowed me to pass only one function argument. A tuple argument produced errors mentioning the infeed buffer size, so I solved it by packing all parameters into a single array.

But when I used it in a model for MCMC, mcmc.run() raised a ‘NotImplementedError’ with a long traceback.

So I think it is better to ask you whether some more class methods must be implemented, in addition to the three I wrote (__init__(), sample(), and log_prob()), for the class to be usable in MCMC inference.

best,

from re import S
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import jax
import jax.random as random
import jax.numpy as jnp

import numpy as np
import scipy
import scipy.stats as stats
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns;
import arviz

def rv_truncated_binomial(nparr):
    """Draw Binomial(n, p) samples truncated to [low, high] by inverse-CDF sampling.

    Host-side helper meant to be called through a JAX host callback; it
    draws from NumPy's global RNG.

    Args:
        nparr: flat array ``[n, p, low, high, *shape]``. Any entries after the
            first four give the output shape (none means a scalar draw).

    Returns:
        float32 array of shape ``shape`` with integer values in ``[low, high]``.
    """
    n, p, low, high = nparr[:4]
    shape = tuple(int(i) for i in nparr[4:])

    binom = scipy.stats.binom(n, p)
    plow = binom.cdf(low - 1)   # P(X < low)
    phigh = binom.cdf(high)     # P(X <= high)

    # ppf(q) is the smallest k with cdf(k) >= q, so q must lie in the
    # half-open interval (plow, phigh] to land in [low, high]. np.random.random
    # returns [0, 1); with u == 0 we would get ppf(plow) == low - 1, so flip
    # the draw to (0, 1].
    u = 1.0 - np.random.random(size=shape)
    q = u * (phigh - plow) + plow
    return binom.ppf(q).astype(np.float32)

class TruncatedBinomial(dist.Distribution):
    """Binomial(total_count, probs) truncated to the integer range [low, high].

    Args:
        total_count: number of trials ``n``.
        probs: success probability ``p``.
        low: inclusive lower bound (default 0).
        high: inclusive upper bound; ``None`` means ``total_count``.
        valid_args: forwarded to ``Distribution`` as ``validate_args``.

    Only scalar parameters are supported: they are packed into one flat
    array ``[n, p, low, high]`` for the host-side sampler
    ``rv_truncated_binomial``.
    """

    def __init__(self, total_count, probs, low=0, high=None, valid_args=None):
        # Resolve the default upper bound before it is used anywhere below.
        if high is None:
            high = total_count
        self.total_count, self.probs, self.low, self.high, self.valid_args = (
            total_count, probs, low, high, valid_args)
        # Flat parameter vector for the host callback.
        self.param = jnp.array([total_count, probs, low, high])
        super().__init__(batch_shape=(), validate_args=valid_args)

        # Normalization constant Z = P(low <= X <= high) = F(high) - F(low - 1).
        self.Z = (self._binomial_cdf(high, total_count, probs)
                  - self._binomial_cdf(low - 1, total_count, probs))

    @staticmethod
    def _binomial_cdf(k, n, p):
        """Return P(X <= k) for X ~ Binomial(n, p), for any integer k.

        Uses P(X <= k) = I_{1-p}(n - k, k + 1) (regularized incomplete beta)
        for 0 <= k < n. betainc is undefined when a shape parameter is 0, so
        k < 0 (CDF 0) and k >= n (CDF 1) are handled explicitly.
        """
        k = jnp.asarray(k)
        k_safe = jnp.clip(k, 0, n - 1)
        cdf = jax.scipy.special.betainc(n - k_safe, k_safe + 1, 1 - p)
        return jnp.where(k < 0, 0.0, jnp.where(k >= n, 1.0, cdf))

    @dist.constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        # Integer values in [low, high].
        return dist.constraints.integer_interval(self.low, self.high)

    def sample(self, key, sample_shape=()):
        """Draw samples on the host via ``rv_truncated_binomial``.

        NOTE: ``key`` is not used; the host function draws from NumPy's global
        RNG, so results are not controlled by the JAX PRNG key.
        """
        shape = sample_shape + self.batch_shape
        if shape == ():
            param = self.param
        else:
            # Append the requested output shape to the parameter vector.
            param = jnp.concatenate((self.param, jnp.array(shape)))
        samples = jax.experimental.host_callback.call(
            rv_truncated_binomial, param,
            result_shape=jax.ShapeDtypeStruct(shape, jnp.float32))
        return samples.astype(jnp.int32)

    def log_prob(self, value):
        """Log-probability of ``value``; ``-inf`` outside ``[low, high]``."""
        log_prob = (dist.Binomial(self.total_count, probs=self.probs).log_prob(value)
                    - jnp.log(self.Z))
        # Element-wise bounds check. A chained ``low <= value <= high`` would
        # call bool() on an array and fail for batched or traced values.
        in_range = (value >= self.low) & (value <= self.high)
        return jnp.where(in_range, log_prob, -jnp.inf)
#

if __name__ == "__main__":
    # Quick empirical check of TruncatedBinomial sampling: for each (low, high)
    # window, draw 100000 samples from a truncated Binomial(10, 0.5) and print
    # the distinct values with their relative frequencies.
    rng_key = random.PRNGKey(0)
    rng_key, sub = jax.random.split(rng_key)

    windows = [(0, 10), (3, 7), (4, 5)]
    for low, high in windows:
        print(f'* low: {low}, high: {high}')
        draws = TruncatedBinomial(10, .5, low=low, high=high).sample(
            rng_key, sample_shape=(100000,))

        values, counts = np.unique(draws, return_counts=True)
        frequencies = counts / len(draws)
        print(values)
        print(frequencies)


Are you using TruncatedBinomial as a latent variable or as an observed variable? Could you provide reproducible code for the error? FYI, I used TruncatedZeroInflatedPoisson as an observed site in MCMC without problems.