Hey,
I’m trying to implement a version of a discrete uniform variable, where the bounds are updated based on another RV:
w \sim N(\mu, \sigma), o_t|w,o_{t-1} \sim U(w,o_{t-1}) . The main issue here is that the uniform distribution is defined over a set, and the bound of the support updated after each iteration.
Here’s the code for my simple model:
def model(self, observation):
w = numpyro.sample("labor_costs", dist.Normal(self.loc, self.scale))
counter_action = numpyro.sample("counter_offer", DiscreteUniform(self.offers, w, observation))
return counter_action
The main issue is that the bounds of python DiscreteUniform
are dynamic. That is, given a value of w (latent) the bounds are updated. I tried implementing my version of the distribution but I think I got some fundamental thing wrong:
class DiscreteUniform(Distribution):
arg_constraints = {"low": constraints.dependent, "high": constraints.dependent}
reparametrized_params = ["low", "high"]
def __init__(self, support_vector, low=None, high=None):
super(DiscreteUniform, self).__init__()
self.low = low
self.high = high
self.support_vector = support_vector
self._support = constraints.interval(low, high)
def _compute_feasible_interval(self):
return self.support_vector[(self.support_vector >= self.low.primal.pval[1]) & (self.support_vector <= self.high)]
def sample(self, key, sample_shape=()):
res = jax.random.choice(key, self._compute_feasible_interval(), sample_shape, replace=False)
return res
def log_prob(self, value):
"""
The probability for each feasible distribution is 1/|n|
:param value:
:return:
"""
# feasible_support = self._compute_feasible_interval()
feasible_support = self.support_vector
prob = 1 / feasible_support.shape[0]
return jnp.log(prob)
@property
def mean(self):
return (self.low + self.high) / 2
@property
def variance(self):
n = self.high - self.low + 1
return (jnp.power(n, 2) - 1) / 12
def enumerate_support(self, expand=True):
pass
def cdf(self, value):
# feasible_support = self._compute_feasible_interval()
feasible_support = self.support_vector
prob = 1 / feasible_support.shape[0]
return prob
def icdf(self, q):
pass
I’d be happy to comments on how to update the support dynamically (from [a,b] to [a’,b’]).
Thanks in advance!