Enumerate support with batch dimensions for a custom distribution

Hi all,

I have a question regarding Predictive and its use with batch dimensions.

Following the tutorials, I have modified the Poisson distribution so that it has a finite domain (a right-truncated Poisson):

RightTruncatedPoisson
import numpyro.distributions as dist
import jax
import jax.numpy as jnp
import numpyro
from numpyro.util import not_jax_tracer
import numpy as np
import scipy

class RightTruncatedPoisson(dist.Distribution):
    """
    A right-truncated Poisson distribution with support ``{0, 1, ..., high}``.

    The mass of a Poisson(``rate``) distribution is restricted to values
    ``<= high`` and renormalized by ``P(X <= high)``.

    :param numpy.ndarray rate: rate of the underlying Poisson distribution.
    :param numpy.ndarray high: inclusive upper bound at which truncation happens.
    """

    arg_constraints = {
        "high": dist.constraints.nonnegative_integer,
        "rate": dist.constraints.positive,
    }
    has_enumerate_support = True

    def __init__(self, rate=1.0, high=0, validate_args=None):
        batch_shape = jax.lax.broadcast_shapes(jnp.shape(high), jnp.shape(rate))
        # Pad both parameters to a common number of dims so they broadcast
        # against each other and against batched values.
        self.high, self.rate = dist.util.promote_shapes(high, rate)
        super().__init__(batch_shape, validate_args=validate_args)

    def log_prob(self, value):
        """Log-probability of ``value``; ``-inf`` for ``value > high``."""
        # Normalizing constant P(X <= high) of the untruncated Poisson.
        m = jax.scipy.stats.poisson.cdf(self.high, self.rate)
        log_p = jax.scipy.stats.poisson.logpmf(value, self.rate)
        return jnp.where(value <= self.high, log_p - jnp.log(m), -jnp.inf)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` by inverse-CDF sampling."""
        shape = sample_shape + self.batch_shape
        float_type = jnp.result_type(float)
        # Smallest positive float as lower bound keeps u strictly above 0.
        minval = jnp.finfo(float_type).tiny
        u = jax.random.uniform(key, shape, minval=minval)
        return self.icdf(u)

    def cdf(self, value):
        """CDF of the truncated distribution; equals 1 for ``value >= high``."""
        m = jax.scipy.stats.poisson.cdf(self.high, self.rate)
        f = jax.scipy.stats.poisson.cdf(value, self.rate)
        # Beyond the truncation point all probability mass has been accumulated.
        return jnp.where(value <= self.high, f / m, 1.0)

    def icdf(self, u):
        """
        Inverse CDF, evaluated on the host via SciPy.

        :param u: uniform quantiles in ``(0, 1)``.
        :return: integer-valued quantiles with the same shape as ``u``.
        """
        result_shape = jax.ShapeDtypeStruct(u.shape, jnp.result_type(float))

        def _host_icdf(args):
            # pure_callback requires the returned array to match result_shape
            # exactly, so cast the SciPy (float64) result and fix its shape.
            out = np.asarray(scipy_truncated_poisson_icdf(args), dtype=result_shape.dtype)
            return np.ascontiguousarray(np.broadcast_to(out, result_shape.shape))

        # jax.pure_callback replaces the deprecated/removed
        # jax.experimental.host_callback.call.
        result = jax.pure_callback(
            _host_icdf,
            result_shape,
            (self.rate, self.high, u),
        )
        return result.astype(jnp.result_type(int))

    @dist.constraints.dependent_property(is_discrete=True)
    def support(self):
        # Finite support {0, 1, ..., high}.
        return dist.constraints.integer_interval(0, self.high)

    def enumerate_support(self, expand=True):
        """
        Return the values ``0, ..., high`` with a leading enumeration dim.

        :param bool expand: if True, broadcast the values to
            ``(high + 1,) + batch_shape``; otherwise keep singleton batch dims.
        :raises NotImplementedError: if ``high`` differs across the batch
            (only detectable when ``high`` is not being traced).
        """
        if not_jax_tracer(self.high):
            high = np.amax(self.high)
            # NB: the error can't be raised if inhomogeneous issue happens when tracing
            if np.amin(self.high) != high:
                raise NotImplementedError(
                    "Inhomogeneous `high` not supported by `enumerate_support`."
                )
        else:
            high = jnp.amax(self.high)
        # Leading enumeration dim followed by singleton batch dims.
        values = jnp.arange(high + 1).reshape(
            (-1,) + (1,) * len(self.batch_shape)
        )
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
        return values

# adapted from https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html#3
def scipy_truncated_poisson_icdf(args):  # Note: all arguments are passed inside a tuple
    """
    Inverse CDF of a right-truncated Poisson distribution, computed with SciPy.

    :param tuple args: ``(rate, high, u)`` -- Poisson rate, inclusive upper
        truncation bound and uniform quantiles in ``(0, 1)``; the three are
        broadcast against each other.
    :return: float array of integer-valued quantiles clipped to ``[0, high]``.
    """
    # ``import scipy`` alone does not load ``scipy.stats`` on older SciPy releases.
    from scipy import stats

    rate, high, u = (np.asarray(a) for a in args)
    density = stats.poisson(rate)
    # Rescale u into (0, P(X <= high)) so ppf lands inside the truncated support.
    normalizer = density.cdf(high)
    x = normalizer * u
    # ppf(0) returns -1 (e.g. when P(X <= high) underflows to 0 for a large
    # rate), so clip the result back into the support {0, ..., high}.
    return np.clip(density.ppf(x), 0, high)

I use it in the following minimal example:

numpyro.infer.Predictive
def model(counts, high=10):
    """
    Sample one right-truncated Poisson count per entry along the last axis of ``counts``.

    :param counts: array of Poisson rates; its last axis is covered by the plate named ``'plate'``.
    :param int high: inclusive truncation bound forwarded to ``RightTruncatedPoisson``.
    """
    n_entries = counts.shape[-1]
    data_plate = numpyro.plate('plate', size=n_entries, dim=-1)

    with data_plate:
        truncated = RightTruncatedPoisson(rate=counts, high=high)
        # Ask for parallel enumeration of this discrete site.
        numpyro.sample(
            'counts_poisson',
            truncated,
            infer={"enumerate": "parallel"},
        )

# Two leading batch dimensions in front of the 100 plate entries.
nbatch0 = 2
nbatch1 = 1
counts = jnp.ones((nbatch0, nbatch1, 100))

mock_samples = numpyro.infer.Predictive(
    model,
    posterior_samples={"counts": counts},
    return_sites=["counts_poisson"],
    batch_ndims=2,
    infer_discrete=True,
)(jax.random.PRNGKey(3), counts=counts)

With `nbatch0, nbatch1 = 1, 1` this works fine, but if I set either of these variables to a value larger than one, it fails with the following error message:

ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_3'}.

Is this something that I can resolve with Vindex?

I think you want to set `Predictive(..., parallel=True)`. Edit: this is not correct.

In your code, `counts` is not a random variable (it is a plain model argument, not a site). You might want to rewrite the model as something like:

def model(counts):
    # Wrap `counts` in a deterministic site named "counts" instead of
    # leaving it as a plain argument.
    counts = numpyro.deterministic("counts", counts)
    ...  # placeholder: the rest of the original model goes here

# then call
# (`...` is a placeholder for the original Predictive arguments; presumably
# exclude_deterministic=False keeps deterministic sites such as "counts" in
# the output -- confirm against the numpyro Predictive docs)
numpyro.infer.Predictive(..., exclude_deterministic=False)

Thanks for your answer. Unfortunately, I still get the same error, even when I use the proposed solution with `numpyro.deterministic` and `exclude_deterministic=False`.

It seems that the `exclude_deterministic` argument does not apply to models with discrete latent variables. Could you open a GitHub issue?

Edit: I opened a GitHub issue here: "Exclude_deterministic argument in Predictive does not apply for models with discrete latents" (Issue #1861, pyro-ppl/numpyro on GitHub).

1 Like