# Sample from discrete uniform with moving bounds

Hey,
I’m trying to implement a version of a discrete uniform variable whose bounds are updated based on another RV:

`w ~ N(mu, sigma)`, `o_t | w, o_{t-1} ~ U(w, o_{t-1})`

The main issue here is that the uniform distribution is defined over a set, and the bound of the support is updated after each iteration.
Here’s the code for my simple model:

```python
def model(self, observation):
    w = numpyro.sample("labor_costs", dist.Normal(self.loc, self.scale))
    counter_action = numpyro.sample(
        "counter_offer", DiscreteUniform(self.offers, w, observation)
    )
    return counter_action
```


The main issue is that the bounds of `DiscreteUniform` are dynamic: given a value of the latent `w`, the bounds are updated. I tried implementing my own version of the distribution, but I think I got something fundamental wrong:

```python
import jax
import jax.numpy as jnp
from numpyro.distributions import Distribution, constraints


class DiscreteUniform(Distribution):
    arg_constraints = {"low": constraints.dependent, "high": constraints.dependent}
    reparametrized_params = ["low", "high"]

    def __init__(self, support_vector, low=None, high=None):
        super().__init__()
        self.low = low
        self.high = high
        self.support_vector = support_vector
        self._support = constraints.interval(low, high)

    def _compute_feasible_interval(self):
        return self.support_vector[
            (self.support_vector >= self.low.primal.pval[1])
            & (self.support_vector <= self.high)
        ]

    def sample(self, key, sample_shape=()):
        return jax.random.choice(
            key, self._compute_feasible_interval(), sample_shape, replace=False
        )

    def log_prob(self, value):
        """The probability of each feasible value is 1 / n."""
        # feasible_support = self._compute_feasible_interval()
        feasible_support = self.support_vector
        prob = 1 / feasible_support.shape[0]
        return jnp.log(prob)

    @property
    def mean(self):
        return (self.low + self.high) / 2

    @property
    def variance(self):
        n = self.high - self.low + 1
        return (jnp.power(n, 2) - 1) / 12

    def enumerate_support(self, expand=True):
        pass

    def cdf(self, value):
        # feasible_support = self._compute_feasible_interval()
        feasible_support = self.support_vector
        prob = 1 / feasible_support.shape[0]
        return prob

    def icdf(self, q):
        pass
```


I’d be happy to get comments on how to update the support dynamically (from [a, b] to [a’, b’]).

I would recommend reparameterizing your model, because numpyro does not have good support for dynamic supports. A reparam version of your model is:

```python
w ~ N(mu, sigma)
base_o ~ U(0, 1).expand([T])
current_o = base_o[0]
o = [current_o]
for current_base_o in base_o[1:]:  # use lax.scan for this loop
    # the following line should be adjusted to reflect how
    # you want to set bounds for your discrete uniform samples
    current_o = w + (current_o - w) * current_base_o
    o.append(current_o)
```
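The loop above can be sketched as a runnable `jax.lax.scan`. This is only an illustration of the reparameterization idea: the function name, the values of `mu`, `sigma`, and `T`, and the interval-shrinking rule `new_o = w + (current_o - w) * u` are taken from the pseudocode and should be adapted to your actual model.

```python
import jax
import jax.numpy as jnp

def reparam_trajectory(key, mu=1.0, sigma=0.5, T=5):
    """Sample w ~ N(mu, sigma) and a trajectory o that contracts toward w."""
    key_w, key_u = jax.random.split(key)
    w = mu + sigma * jax.random.normal(key_w)   # w ~ N(mu, sigma)
    base_o = jax.random.uniform(key_u, (T,))    # base_o ~ U(0, 1).expand([T])

    def step(current_o, u):
        # shrink the interval [w, current_o] by the uniform draw u
        new_o = w + (current_o - w) * u
        return new_o, new_o

    _, rest = jax.lax.scan(step, base_o[0], base_o[1:])
    return w, jnp.concatenate([base_o[:1], rest])
```

Because each step multiplies the distance to `w` by a number in [0, 1), the samples move monotonically toward `w`, mirroring the shrinking support of the original model.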


@fehiepsi Thank you very much.
I’m somewhat confused: in the discrete case I’m able to sample both from o_{t+1} and from w | o_t, so where is the problem with the model as currently specified?
In your suggested solution, what would be an efficient way to trim the array of possible values? Would you recommend implementing a discrete uniform distribution, similar in principle to your implementation of the truncated binomial (Truncated Binomial, How to implement?)?
Thanks for the kind help!

I don’t quite understand your feasible interval, but to take a discrete sample from [a, a + n], you can do

```python
u ~ Uniform()
low = jnp.ceil(a)
high = jnp.floor(a) + n
sample = jnp.floor(low + (high + 1 - low) * u)
```

Here `u` is an element of `base_o` in my last comment.
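As a runnable check of this inverse-CDF mapping (the helper name is made up for illustration): `u = 0` maps to the lowest integer `ceil(a)`, and `u` just below 1 maps to the highest integer `floor(a) + n`.

```python
import jax.numpy as jnp

def discrete_from_uniform(u, a, n):
    """Map u ~ Uniform(0, 1) to an integer in [ceil(a), floor(a) + n]."""
    low = jnp.ceil(a)
    high = jnp.floor(a) + n
    # each of the (high - low + 1) integers gets an equal slice of [0, 1)
    return jnp.floor(low + (high + 1 - low) * u)
```

For example, with `a = 2.3` and `n = 3`, the feasible integers are {3, 4, 5}: `u = 0.0` yields 3 and `u = 0.999` yields 5.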

Regarding `o`, you can add

```python
o = numpyro.deterministic("o", jnp.stack(o).astype(jnp.float32))
```

(Note: `jnp.stack` rather than `jnp.concatenate`, since the entries of `o` are scalars.)


to record its value.

Both models (the original model and the reparam model) should be equivalent from a generative perspective (i.e. w ~ N(mu, sigma), o_t | w, o_{t-1} ~ U(w, o_{t-1})). But numpyro inference methods will only work for the reparam model (which has continuous latent variables and static supports).

Thank you very much.

> I don’t quite understand your feasible interval

Consider a guessing game in which an agent selects a number between low and high. The other agent asks questions about this number and only gets answers of the form “higher” or “lower”. The number is an integer in the set [low, high], and after each trial the set of potential values is truncated based on the response. Now the answer is passed through a Gaussian channel, hence the Gaussian latent variable. Is there any intention to implement dynamic truncation like in Stan?
Thanks for the kind help!

Does Stan support dynamic truncation for discrete latent variables? Could you point me to some reference? For continuous latent variables, it looks like Stan uses the above approach.

Note that, at least for some problems, you can always convert to a problem with a fixed discrete domain by padding your probabilities appropriately with zeros or negligible probabilities. E.g. the histogram [0.25, 0.5, 0.25] becomes [0.0, 0.0, 0.25, 0.5, 0.25, 0.0, 0.0, 0.0], or what have you.
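A minimal sketch of this padding idea, assuming a histogram over a moving window [lo, lo + k] embedded in a fixed domain of size K; the helper name and argument layout are made up:

```python
import jax.numpy as jnp
from jax import lax

def pad_histogram(probs, lo, K):
    """Embed probabilities over [lo, lo + len(probs) - 1] into a fixed domain [0, K - 1].

    Values outside the current window get probability zero, so the
    resulting vector always has static shape (K,) and can feed a
    categorical with a fixed support.
    """
    # dynamic_update_slice works even when lo is a traced value
    return lax.dynamic_update_slice(jnp.zeros(K), probs, (lo,))
```

With a static output shape, the support never changes from the inference algorithm’s point of view; only the locations of the nonzero probabilities move.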


Indeed, I probably misunderstood this: the lower and upper bounds are sampled from a known distribution but are latent, and I assumed (falsely) that they were updated after each observation.