Thanks a lot @fehiepsi
I think I made a class and its usage seems OK.

betainc
was really useful. I would have spent a lot of time to find it.

jax.experimental.host_callback.call()
allowed me only one function argument. A tuple type argument produced errors mentioning infeed buffersize. I simply solved it by using only one.
But when I used it in a model for MCMC, the mcmc.run()
produces ‘NotImplementedError’ with a long listing of related codes.
Probably, I think it would be better to ask you whether or not some more class methods must be implemented in addition to those three, __init__()
, sample()
, and log_prob()
in order for the class to be used for mcmc inference.
best,
from re import S
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import jax
import jax.random as random
import jax.numpy as jnp
import numpy as np
import scipy
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns;
import arviz
def rv_truncated_binomial (nparr):
# print('nparr: ', nparr)
n, p, low, high = nparr[:4]
shape = tuple(int(i) for i in nparr[4:])
# print('args: ', n, p, low, high, shape)
dist = scipy.stats.binom(n, p)
plow = dist.cdf(low1)
phigh = dist.cdf(high)
# print(f'plow: {plow} phigh: {phigh}')
r = np.random.random(size=shape) # r \in [0, 1)
mr = r * (phigh  plow) + plow # now, low <= ppf(r) <= high
# print(f'r: {r}, mr: {mr}, ppf: {dist.ppf(mr)}')
return dist.ppf(mr).astype(np.float32)
class TruncatedBinomial(dist.Distribution):
def __init__(self, total_count, probs, low=0, high=None, valid_args=None):
self.total_count, self.probs, self.low, self.high, self.valid_args = total_count, probs, low, high, valid_args
self.param = jnp.array([total_count, probs, low, high])
if high is None:
high = total_count # Use dist.Binomial for normal Binomial
super().__init__()
# normalization constant for the truncated binomial
self.Z = jax.scipy.special.betainc( a = total_count  high,
b = 1 + high,
x = probs) \
 jax.scipy.special.betainc(a = total_count  (low  1),
b = 1 + (low  1),
x = probs)
pass
def sample(self, key, sample_shape=()):
shape = sample_shape + self.batch_shape
if shape == ():
param = self.param
else:
param = jnp.concatenate((self.param, jnp.array(shape)))
# print('param: ', param)
samples = jax.experimental.host_callback.call(
rv_truncated_binomial, param,
result_shape=jax.ShapeDtypeStruct(shape, jnp.float32))
return samples.astype(jnp.int32)
def log_prob(self, value):
"""
low <= value <= high
"""
log_prob = dist.Binomial(self.total_count, probs=self.probs).log_prob(value)  self.Z
return jnp.where(self.low <= value <= self.high, log_prob, 1e10) # jnp.inf ?
#
if __name__ == "__main__":
rng_key = random.PRNGKey(0)
rng_key, sub = jax.random.split(rng_key)
lhs = [(0, 10), (3, 7), (4, 5)]
for low, high in lhs:
print(f'* low: {low}, high: {high}')
tb = TruncatedBinomial(10, .5, low=low, high=high)
s = tb.sample(rng_key, sample_shape=(100000,))
u, c = np.unique(s, return_counts=True)
print(u)
print(c/len(s))