Hi all,
I have a question regarding `Predictive` and its use with batch dimensions.
I have modified the Poisson distribution so that it has a finite domain according to the tutorials:
RightTruncatedPoisson
import numpyro.distributions as dist
import jax
import jax.numpy as jnp
import numpyro
from numpyro.util import not_jax_tracer
import numpy as np
import scipy
class RightTruncatedPoisson(dist.Distribution):
    """
    A truncated Poisson distribution.

    The distribution is a Poisson with the given ``rate`` restricted to the
    integers ``0, 1, ..., high`` and renormalized by ``Poisson.cdf(high)``.

    :param numpy.ndarray high: high bound at which truncation happens
        (inclusive).
    :param numpy.ndarray rate: rate of the Poisson distribution.
    """

    arg_constraints = {
        "high": dist.constraints.nonnegative_integer,
        "rate": dist.constraints.positive,
    }
    has_enumerate_support = True

    def __init__(self, rate=1.0, high=0, validate_args=None):
        """Store ``rate`` and ``high`` and derive the batch shape from both."""
        batch_shape = jax.lax.broadcast_shapes(jnp.shape(high), jnp.shape(rate))
        self.high, self.rate = dist.util.promote_shapes(high, rate)
        super().__init__(batch_shape, validate_args=validate_args)

    def log_prob(self, value):
        """
        Return the log-probability of ``value``.

        Values above ``high`` get ``-inf``; all other values get the Poisson
        log-pmf minus the log of the truncation normalizer.
        """
        # Probability mass the untruncated Poisson puts on {0, ..., high}.
        m = jax.scipy.stats.poisson.cdf(self.high, self.rate)
        log_p = jax.scipy.stats.poisson.logpmf(value, self.rate)
        return jnp.where(value <= self.high, log_p - jnp.log(m), -jnp.inf)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` by inverse-CDF."""
        shape = sample_shape + self.batch_shape
        float_type = jnp.result_type(float)
        # Use the smallest positive normal float as the lower bound so that
        # the uniform draw is never exactly zero.
        minval = jnp.finfo(float_type).tiny
        u = jax.random.uniform(key, shape, minval=minval)
        return self.icdf(u)  # Using `host_callback`

    def cdf(self, value):
        """
        Return the cumulative distribution function evaluated at ``value``.

        Inside the support this is the Poisson CDF divided by the truncation
        normalizer; above ``high`` all mass has been accumulated, so it is 1.
        """
        m = jax.scipy.stats.poisson.cdf(self.high, self.rate)
        f = jax.scipy.stats.poisson.cdf(value, self.rate)
        # Fix: the CDF beyond the upper end of the support is 1, not 0.
        return jnp.where(value <= self.high, f / m, 1.0)

    def icdf(self, u):
        """
        Return the inverse CDF evaluated at ``u``, cast to the default int type.

        The computation is delegated to ``scipy_truncated_poisson_icdf`` on
        the host, which receives ``(rate, high, u)`` as a single tuple.
        """
        # NOTE(review): `jax.experimental.host_callback` has been deprecated
        # in recent JAX releases in favor of `jax.pure_callback` -- confirm
        # against the installed JAX version before upgrading.
        result_shape = jax.ShapeDtypeStruct(u.shape, jnp.result_type(float))
        result = jax.experimental.host_callback.call(
            scipy_truncated_poisson_icdf,
            (self.rate, self.high, u),
            result_shape=result_shape,
        )
        return result.astype(jnp.result_type(int))

    @dist.constraints.dependent_property(is_discrete=True)
    def support(self):
        """Return the support of the distribution: the integers ``0..high``."""
        # Fix: the truncated domain is [0, high]; `integer_greater_than(high)`
        # described exactly the values that have zero probability.
        return dist.constraints.integer_interval(0, self.high)

    def enumerate_support(self, expand=True):
        """
        Return all values ``0, ..., high`` along a new leading axis.

        The distribution sets ``has_enumerate_support = True``, so this
        method provides the enumerated values.

        :param bool expand: whether to broadcast the values over
            ``batch_shape``.
        :raises NotImplementedError: if ``high`` is a concrete array whose
            entries are not all equal.
        """
        if not_jax_tracer(self.high):
            high = np.amax(self.high)
            # NB: the error can't be raised if inhomogeneous issue happens when tracing
            if np.amin(self.high) != high:
                raise NotImplementedError(
                    "Inhomogeneous total count not supported" " by `enumerate_support`."
                )
        else:
            high = jnp.amax(self.high)
        # Shape (high + 1, 1, ..., 1): the enumeration axis comes first,
        # followed by one singleton axis per batch dimension.
        values = jnp.arange(high + 1).reshape(
            (-1,) + (1,) * len(self.batch_shape)
        )
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
        return values
# adapted from https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html#3
def scipy_truncated_poisson_icdf(args):  # Note: all arguments are passed inside a tuple
    """
    Evaluate the inverse CDF of a right-truncated Poisson with SciPy.

    :param tuple args: ``(rate, high, u)``; each entry is converted to a
        NumPy array.
    :return: the Poisson quantile at ``u`` rescaled by ``Poisson.cdf(high)``.
    """
    rate, high, u = (np.asarray(arg) for arg in args)
    poisson = scipy.stats.poisson(rate)
    # Rescale u from (0, 1) to (0, cdf(high)) so the quantile stays within
    # the truncated range.
    return poisson.ppf(poisson.cdf(high) * u)
I use this distribution in the following minimal example with `numpyro.infer.Predictive`:
def model(counts, high=10):
    """
    Sample one right-truncated Poisson value per entry of the last axis.

    :param counts: array used as the Poisson rate; its last axis sets the
        size of the plate.
    :param high: upper truncation bound passed to ``RightTruncatedPoisson``.
    """
    with numpyro.plate('plate', size=counts.shape[-1], dim=-1):
        truncated_poisson = RightTruncatedPoisson(rate=counts, high=high)
        numpyro.sample(
            'counts_poisson',
            truncated_poisson,
            infer={"enumerate": "parallel"},
        )
# Sizes of the two leading batch dimensions of the mock input.
nbatch0 = 2
nbatch1 = 1
counts = jnp.ones(shape=(nbatch0, nbatch1, 100))
# Build the predictive object and call it immediately with a fixed PRNG key.
mock_samples = numpyro.infer.Predictive(
    model,
    posterior_samples={"counts": counts},
    return_sites=["counts_poisson"],
    infer_discrete=True,
    batch_ndims=2,
)(jax.random.PRNGKey(3), counts=counts)
For `nbatch0, nbatch1 = 1, 1` this works fine, but if I set either of these variables to a value larger than one, it breaks and I get the following error message:
ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_3'}.
Is this something that I can resolve with `Vindex`?