Trying to write a Uniform Discrete distribution

The Categorical distribution only returns indices into the domain. For what I’m trying to do, I need to bind the actual domain values to the trace so I can condition on those values later. For example, if my domain is the integers {3, 6, 9}, I want a sample statement to bind those values to the trace instead of {0, 1, 2}.
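A minimal sketch of the gap described above, using plain `torch.distributions.Categorical` (Pyro's `dist.Categorical` is built on it). The distribution itself only ever yields indices, and the domain values come from indexing a separate tensor. The names `vals` and `weights` are illustrative.

```python
import torch
from torch.distributions import Categorical

vals = torch.tensor([3, 6, 9])    # the domain we actually care about
weights = torch.ones(len(vals))   # equal weights -> uniform over vals
cat = Categorical(weights)

idx = cat.sample((5,))            # five draws, each an index in {0, 1, 2}
values = vals[idx]                # the same draws mapped onto {3, 6, 9}
print(idx.tolist(), values.tolist())
```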

So I wrote the following discrete uniform version of Categorical.

import torch
from pyro.distributions import Categorical

class UniformDiscrete(Categorical):
    def __init__(self, vals):
        weights = torch.ones(vals.size())
        self._categorical = Categorical(weights)
        self._vals = vals
    def sample(self, sample_shape=torch.Size()):
        idx = self._categorical.sample(sample_shape)
        return self._vals[idx]
    def log_prob(self, value):
        # scores index 0 for every input (uniform, so all in-domain values match)
        zeros = torch.zeros_like(value, dtype=torch.long)
        return self._categorical.log_prob(zeros)

Seems to work, but I’m a bit new to thinking in terms of Torch’s tensors, so I’m not sure. Am I in the ballpark? Thanks.

Well, for one, your log_prob completely ignores its input. Why do you need to wrap the Categorical distribution instead of using the sampled values as indices? Something like (from your example):

vals = torch.tensor([3, 6, 9])
weights = torch.ones(len(vals))  # equal weights, i.e. uniform over vals
idx = pyro.sample('index', dist.Categorical(weights))
value_you_need = vals[idx.long()]

Then, to calculate the log probability of an observed value, do the reverse:

idx = (vals == sample).nonzero().float()  # this assumes each category is unique
pyro.sample('obs', dist.Categorical(weights), obs=idx)
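The `nonzero()` lookup above handles one scalar observation at a time. A minimal sketch of a batched alternative in plain PyTorch, assuming every observed value occurs exactly once in `vals` (`value_to_index` is a hypothetical helper name):

```python
import torch
from torch.distributions import Categorical

vals = torch.tensor([3, 6, 9])
cat = Categorical(torch.ones(len(vals)))

def value_to_index(vals, value):
    """Map domain values to their positions in `vals`.

    Assumes each entry of `value` occurs exactly once in `vals`; a value
    with no match silently maps to index 0, so check membership first
    if that can happen.
    """
    matches = value.unsqueeze(-1) == vals            # value.shape + (len(vals),)
    positions = torch.arange(len(vals))              # 0, 1, ..., n-1 (int64)
    return (matches * positions).sum(-1)             # position of the single match

obs = torch.tensor([9, 3, 6, 9])
idx = value_to_index(vals, obs)      # one int64 index per observation
log_p = cat.log_prob(idx)            # one log-probability per observation
```

Because the comparison broadcasts over a trailing axis, the same helper works for a scalar observation or any batch shape.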

log_prob is completely ignoring the input
It’s a uniform distribution, so every value should have the same log_prob; I was passing index 0 (the first element of the domain) as a stand-in.
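For a uniform distribution over n values, every in-domain value has log-probability -log(n), so scoring index 0 gives the right number for values in the domain. It also gives -log(n) to values that are not in the domain, where the answer should be -inf. A minimal sketch of a version that actually looks at its input, with `uniform_log_prob` as a hypothetical helper name:

```python
import math
import torch
from torch.distributions import Categorical

vals = torch.tensor([3, 6, 9])
cat = Categorical(torch.ones(len(vals)))

# Every index of a uniform Categorical has the same log-probability, -log(3).
print(cat.log_prob(torch.tensor(0)).item(), -math.log(len(vals)))

def uniform_log_prob(vals, value):
    """log p(value) under a uniform distribution over the entries of `vals`."""
    in_domain = (value.unsqueeze(-1) == vals).any(-1)
    return torch.where(in_domain,
                       torch.tensor(-math.log(len(vals))),
                       torch.tensor(-math.inf))

lp = uniform_log_prob(vals, torch.tensor([3, 4, 9]))  # 4 is outside the domain
```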

why do you need to wrap the categorical distribution instead of using the sample values as indices?
Looks like your solution works, but I’m working on a canonical problem where I want to be picky about which names get used in the program trace.

Using your code, I made the following changes. Does this look okay?

Thanks for your feedback.

class CategoricalVal(Categorical):
    def __init__(self, vals, probs):
        self.categorical = Categorical(probs)
        self.vals = vals
    def sample(self, sample_shape=torch.Size()):
        idx = self.categorical.sample()
        return self.vals[idx.long()]
    def log_prob(self, value):
        idx = (self.vals == value).nonzero().float()
        return self.categorical.log_prob(idx)

Since you’re inheriting from the Categorical class, it may be more Pythonic to write:

class CategoricalVal(Categorical):
    def __init__(self, vals, probs):
        self.vals = vals
        super(CategoricalVal, self).__init__(probs)

    def sample(self, sample_shape=torch.Size()):
        sample = super(CategoricalVal, self).sample(sample_shape)
        return self.vals[sample.long()]

    def support(self):
        ...
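One thing the elided `support` has to deal with: the `support` inherited from `torch.distributions.Categorical` (which Pyro's `Categorical` builds on) describes indices, so a subclass whose `sample` returns domain values would advertise the wrong support. A quick check, assuming current PyTorch behavior:

```python
import torch
from torch.distributions import Categorical

cat = Categorical(torch.ones(3))
support = cat.support                   # constraint over indices 0..2, not {3, 6, 9}

print(support.check(torch.tensor(2)))   # an index: accepted
print(support.check(torch.tensor(6)))   # a domain value: rejected
```

With argument validation enabled (the default in recent PyTorch releases), the inherited `log_prob` would likewise reject a domain value such as 6, which is why a value-returning subclass has to translate back to indices before delegating.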

  1. Use composition rather than inheritance.
  2. Note that Categorical samples are already LongTensors, so there’s no need for .float().
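A small check of both points, in plain PyTorch: samples come back as int64 and can index `vals` directly, and inherited helpers such as `enumerate_support` still speak in indices, which is one argument for wrapping the Categorical rather than subclassing it.

```python
import torch
from torch.distributions import Categorical

vals = torch.tensor([3, 6, 9])
cat = Categorical(torch.ones(len(vals)))

idx = cat.sample()
print(idx.dtype)                   # int64, usable directly as an index
print(vals[idx])                   # no .long() or .float() conversion needed

# Inherited helpers work in index space: this enumerates 0..2, not vals.
print(cat.enumerate_support())
```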

How about

import torch
import pyro.distributions as dist

class CategoricalVals(dist.TorchDistribution):
    arg_constraints = {}  # nothing extra to validate; the wrapped Categorical checks probs

    def __init__(self, vals, probs):
        self.vals = vals
        self.categorical = dist.Categorical(probs)
        super(CategoricalVals, self).__init__(self.categorical.batch_shape,
                                              self.categorical.event_shape)

    def sample(self, sample_shape=torch.Size()):
        return self.vals[self.categorical.sample(sample_shape)]

    def log_prob(self, value):
        # position of each value in self.vals (assumes each occurs exactly once)
        matches = value.unsqueeze(-1) == self.vals
        idx = (matches * torch.arange(len(self.vals))).sum(-1)
        return self.categorical.log_prob(idx)
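A usage sketch, assuming a composition-based wrapper along the lines of the one above. To keep it self-contained and runnable without Pyro, it subclasses `torch.distributions.Distribution` directly instead of Pyro's `TorchDistribution` (an assumption; `pyro.sample` needs the Pyro base class). It also adds a hypothetical `enumerate_support` that returns the domain values rather than indices, written for the unbatched case only.

```python
import torch
from torch.distributions import Categorical, Distribution

class CategoricalVals(Distribution):
    """Categorical over arbitrary values `vals` (composition-based sketch)."""
    arg_constraints = {}           # the wrapped Categorical validates probs
    has_enumerate_support = True

    def __init__(self, vals, probs):
        self.vals = vals
        self.categorical = Categorical(probs)
        super().__init__(self.categorical.batch_shape,
                         self.categorical.event_shape)

    def sample(self, sample_shape=torch.Size()):
        return self.vals[self.categorical.sample(sample_shape)]

    def log_prob(self, value):
        # Position of each value in `vals` (assumes it occurs exactly once).
        matches = value.unsqueeze(-1) == self.vals
        idx = (matches * torch.arange(len(self.vals))).sum(-1)
        return self.categorical.log_prob(idx)

    def enumerate_support(self, expand=True):
        # Domain values instead of indices 0..n-1 (unbatched case only).
        return self.vals

d = CategoricalVals(torch.tensor([3, 6, 9]), torch.tensor([0.2, 0.3, 0.5]))
draws = d.sample((1000,))                                 # 1000 values from {3, 6, 9}
total = d.log_prob(d.enumerate_support()).exp().sum()     # probabilities over the domain
```

Summing `exp(log_prob)` over the enumerated support is a cheap sanity check that the value-to-index mapping is consistent with the underlying Categorical.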

BTW, what is your use case? What kind of values are you using? Are they tensors, and if so, what is their dtype?