Pyro.sample from RelaxedOneHotCategorical doesn't return sampled indexes

sampled_indexes = pyro.sample(f"{address}_{index}", pyro.distributions.RelaxedOneHotCategorical(1, char_dist), obs=observed[index]).squeeze(0)

I have this code where sampled_indexes becomes a tensor of probabilities over a categorical distribution of size 32, rather than a sampled index. How do I get it to sample an index instead?

I implemented the following distribution to get indices, although I must say I am not 100% sure it is correct. It is similar to RelaxedOneHotCategorical and RelaxedOneHotCategoricalStraightThrough.

import torch
from torch.distributions import constraints
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import clamp_probs
# Pyro's wrapped TransformedDistribution, so the class works with pyro.sample
from pyro.distributions import TransformedDistribution


class RelaxedCategoricalStraightThrough(TransformedDistribution):
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = ExpRelaxedCategorical(temperature, probs, logits, validate_args=validate_args)
        super(RelaxedCategoricalStraightThrough, self).__init__(
            base_dist, ExpTransform(), validate_args=validate_args
        )
    
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(RelaxedCategoricalStraightThrough, _instance)
        return super(RelaxedCategoricalStraightThrough, self).expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs

    def rsample(self, sample_shape=torch.Size()):
        soft_sample = super().rsample(sample_shape)
        soft_sample = clamp_probs(soft_sample)
        hard_sample = QuantizeCategorical.apply(soft_sample)
        return hard_sample

    def log_prob(self, value):
        # Score the soft sample stashed by QuantizeCategorical, if present.
        value = getattr(value, "_unquantize", value)
        return super().log_prob(value)


class QuantizeCategorical(torch.autograd.Function):
    @staticmethod
    def forward(ctx, soft_value):
        # Quantize the soft (simplex) sample to its argmax index and stash the
        # soft sample on the returned tensor so log_prob can recover it.
        argmax = soft_value.max(-1)[1]
        if argmax.dim() < soft_value.dim():
            argmax = argmax.unsqueeze(-1)
        hard_value = torch.zeros_like(argmax)
        hard_value._unquantize = soft_value
        return hard_value.copy_(argmax)

    @staticmethod
    def backward(ctx, grad):
        # Straight-through estimator: pass the gradient through unchanged.
        return grad
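
For reference, here is roughly how I call it in my model; char_dist stands in for my 32-way probability vector and the sample site name is made up:

import torch
import pyro

def model():
    # char_dist is a placeholder: uniform probabilities over 32 categories
    char_dist = torch.full((32,), 1.0 / 32)
    temperature = torch.tensor(1.0)
    # rsample quantizes the soft sample to its argmax index via QuantizeCategorical
    sampled_index = pyro.sample(
        "char_0",
        RelaxedCategoricalStraightThrough(temperature, probs=char_dist),
    )
    return sampled_index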

I’m assuming y’all want to sample at prediction time, after you’ve already trained with a relaxed distribution?

You can’t sample indices from a relaxed categorical distribution because the relaxation happens in a bigger space, the space of vectors rather than the set of integers. Two things you can do are:

  1. Define a relaxed ordinal distribution that samples e.g. from a Beta instead of a Binomial. This is possible because there is a natural embedding of the index space (e.g. {0,1,2,3}) into the relaxed space (e.g. [0,3]). No such embedding exists for non-ordered categorical variables.
  2. If your distribution doesn’t have ordinal structure, you can simply use a Categorical distribution for prediction: copy the .probs or .logits over and create a new distribution (see the sketch below).
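
Something like this minimal sketch (the probs and temperature here are placeholder values, not taken from your model):

import torch
import pyro.distributions as dist

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])  # placeholder probabilities

# training-time distribution (relaxed)
relaxed = dist.RelaxedOneHotCategorical(torch.tensor(1.0), probs=probs)

# prediction-time distribution: reuse the learned probs in a plain Categorical
categorical = dist.Categorical(probs=relaxed.probs)
index = categorical.sample()  # an integer index rather than a point on the simplex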

@fritzo Thanks for the reply. I understand your point.

I was trying to sample during training. I thought a relaxed distribution would be suitable for my case in order to avoid the cost of enumeration, since I aim to predict which of m=90 categories each of n=60 data points belongs to, and both m and n are large. Besides, in my experiments my custom RelaxedCategoricalStraightThrough distribution worked better than a Categorical distribution (without enumeration). What could be the reason behind this?

Also, if one should just use a Categorical distribution anyway, why does a new distribution need to be created?

Sorry, I don’t fully understand your RelaxedCategoricalStraightThrough. The shapes don’t quite make sense to me, e.g. .support = simplex but your hard_value looks like it has a different shape. Similarly, I don’t understand how QuantizeCategorical.backward = lambda ctx, grad: grad can get the shapes right :confused: But I’m not the best person to ask about relaxed categorical distributions :laughing:

I was thinking you could use a relaxed distribution for training, then a regular Categorical distribution for prediction.
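
For example, roughly like this (the site names, shapes and temperature are placeholders based on the 60 data points and 90 categories you mentioned; this is only a sketch of the pattern, not a full model):

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(predict=False):
    # placeholder: one learnable 90-way probability vector per data point
    probs = pyro.param("assignment_probs",
                       torch.ones(60, 90) / 90,
                       constraint=constraints.simplex)
    with pyro.plate("data", 60):
        if predict:
            # prediction: a plain Categorical returns integer indices
            return pyro.sample("assignment", dist.Categorical(probs))
        # training: a relaxed sample is a point on the 90-dimensional simplex
        return pyro.sample("assignment",
                           dist.RelaxedOneHotCategorical(torch.tensor(1.0), probs=probs))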