```
sampled_indexes = pyro.sample(f"{address}_{index}",
                              pyro.distributions.RelaxedOneHotCategorical(1, char_dist),
                              obs=observed[index]).squeeze(0)
```

I have this code where `sampled_indexes` becomes a tensor of probabilities over a categorical distribution of size 32 rather than a single sampled index. How do I get it to sample an index instead?
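For context, the quickest workaround I tried was collapsing the relaxed one-hot sample to an index with `argmax`. A minimal torch-only sketch (with a made-up 32-way `char_dist` standing in for my real one):

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

# made-up 32-way probability vector standing in for char_dist
char_dist = torch.softmax(torch.randn(32), dim=-1)
one_hot = RelaxedOneHotCategorical(torch.tensor(1.0), probs=char_dist).sample()
index = one_hot.argmax(-1)  # a plain integer index, but no longer differentiable
```

The `argmax` gives an index, but it kills the gradient, which is why I tried the straight-through distribution below instead.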

I implemented the following distribution to get indices, although I must say I am not 100% sure it is correct. It has similarities with `RelaxedOneHotCategorical` and `RelaxedOneHotCategoricalStraightThrough`.

```
import torch
from torch.distributions import constraints
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import clamp_probs


class QuantizeCategorical(torch.autograd.Function):
    @staticmethod
    def forward(ctx, soft_value):
        # hard forward pass: replace the soft simplex sample with its argmax index
        argmax = soft_value.max(-1)[1]
        if argmax.dim() < soft_value.dim():
            argmax = argmax.unsqueeze(-1)
        hard_value = torch.zeros_like(argmax)
        hard_value._unquantize = soft_value
        return hard_value.copy_(argmax)

    @staticmethod
    def backward(ctx, grad):
        # straight-through: pass the gradient back unchanged
        return grad


class RelaxedCategoricalStraightThrough(TransformedDistribution):
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = ExpRelaxedCategorical(temperature, probs, logits,
                                          validate_args=validate_args)
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(RelaxedCategoricalStraightThrough, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs

    def rsample(self, sample_shape=torch.Size()):
        soft_sample = super().rsample(sample_shape)
        soft_sample = clamp_probs(soft_sample)
        hard_sample = QuantizeCategorical.apply(soft_sample)
        return hard_sample

    def log_prob(self, value):
        # score the soft sample stashed by QuantizeCategorical, if present
        value = getattr(value, "_unquantize", value)
        return super().log_prob(value)
```
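For reference, the straight-through trick my class relies on can also be sketched without a custom `autograd.Function`, using the detach idiom (a generic illustration of the technique, not my code above):

```python
import torch

logits = torch.randn(5, requires_grad=True)
soft = torch.softmax(logits, dim=-1)  # differentiable relaxed sample
hard = torch.nn.functional.one_hot(soft.argmax(-1), num_classes=5).float()
# forward value is the hard one-hot; the gradient flows through `soft`
straight_through = hard + soft - soft.detach()
loss = (straight_through * torch.arange(5.0)).sum()
loss.backward()  # logits.grad is populated via the soft path
```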

I’m assuming y’all want to sample at prediction time, after you’ve already trained with a relaxed distribution?

You can’t sample indices from a relaxed categorical distribution because the relaxation happens in a bigger space: the space of vectors rather than the set of integers. Two things you can do are:

- Define a relaxed ordinal distribution that samples e.g. from a Beta instead of a Binomial. This is possible because there is a natural embedding of the index space (e.g. {0,1,2,3}) into the relaxed space (e.g. [0,3]). No such embedding exists for non-ordered categorical variables.
- If your distribution doesn’t have ordinal structure, you can simply use a `Categorical` distribution for prediction. Copy the `.probs` or `.logits` over and create a new distribution.
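Concretely, copying the parameters over might look like this (a sketch with placeholder logits in place of your trained ones):

```python
import torch
from torch.distributions import Categorical, RelaxedOneHotCategorical

trained = RelaxedOneHotCategorical(torch.tensor(1.0), logits=torch.randn(32))
# reuse the learned logits in an ordinary Categorical for prediction
predictive = Categorical(logits=trained.logits)
index = predictive.sample()  # an integer index, not a simplex vector
```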

@fritzo Thanks for the reply. I understand your point.

I was trying to sample during training. I thought a relaxed distribution would be suitable for my case in order to avoid the cost of enumeration, because I aim to predict which category (`m` = 90) each data point (`n` = 60) belongs to, and `m` and `n` are high. Besides, I have seen in my experiments that using my custom `RelaxedCategoricalStraightThrough` distribution worked better than using a `Categorical` distribution (without enumeration). What could be the reason behind this?

If one should just use a `Categorical` distribution anyway, why create a new distribution at all?

Sorry, I don’t fully understand your `RelaxedCategoricalStraightThrough`. The shapes don’t quite make sense to me, e.g. `.support = simplex` but your `hard_value` looks like it has a different shape. Similarly, I don’t understand how `QuantizeCategorical.backward = lambda ctx, grad: grad` can get the shapes right. But I’m not the best person to ask about relaxed categorical distributions.

I was thinking you could use a relaxed distribution for training, then a regular `Categorical` distribution for prediction.