```
sampled_indexes = pyro.sample(f"{address}_{index}",
                              pyro.distributions.RelaxedOneHotCategorical(1, char_dist),
                              obs=observed[index]).squeeze(0)
```

I have this code where `sampled_indexes` becomes a tensor of probabilities over a categorical distribution of size 32 rather than a single sampled index. How do I get it to sample an index instead?
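For context, the quickest workaround I tried was collapsing the relaxed one-hot sample to an index with `argmax`. A minimal torch-only sketch (with a made-up 32-way `char_dist` standing in for my real one):

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

# made-up 32-way probability vector standing in for char_dist
char_dist = torch.softmax(torch.randn(32), dim=-1)
one_hot = RelaxedOneHotCategorical(torch.tensor(1.0), probs=char_dist).sample()
index = one_hot.argmax(-1)  # a plain integer index, but no longer differentiable
```

The `argmax` gives an index, but it kills the gradient, which is why I tried the straight-through distribution below instead.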

I implemented the following distribution to get indices, although I must say I am not 100% sure it is correct. It has similarities with `RelaxedOneHotCategorical` and `RelaxedOneHotCategoricalStraightThrough`.

```
import torch
from torch.distributions import constraints
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import clamp_probs


class QuantizeCategorical(torch.autograd.Function):
    @staticmethod
    def forward(ctx, soft_value):
        # hard forward pass: replace the soft simplex sample with its argmax index
        argmax = soft_value.max(-1)[1]
        if argmax.dim() < soft_value.dim():
            argmax = argmax.unsqueeze(-1)
        hard_value = torch.zeros_like(argmax)
        hard_value._unquantize = soft_value
        return hard_value.copy_(argmax)

    @staticmethod
    def backward(ctx, grad):
        # straight-through: pass the gradient back unchanged
        return grad


class RelaxedCategoricalStraightThrough(TransformedDistribution):
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = ExpRelaxedCategorical(temperature, probs, logits,
                                          validate_args=validate_args)
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(RelaxedCategoricalStraightThrough, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs

    def rsample(self, sample_shape=torch.Size()):
        soft_sample = super().rsample(sample_shape)
        soft_sample = clamp_probs(soft_sample)
        hard_sample = QuantizeCategorical.apply(soft_sample)
        return hard_sample

    def log_prob(self, value):
        # score the soft sample stashed by QuantizeCategorical, if present
        value = getattr(value, "_unquantize", value)
        return super().log_prob(value)
```
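For reference, the straight-through trick my class relies on can also be sketched without a custom `autograd.Function`, using the detach idiom (a generic illustration of the technique, not my code above):

```python
import torch

logits = torch.randn(5, requires_grad=True)
soft = torch.softmax(logits, dim=-1)  # differentiable relaxed sample
hard = torch.nn.functional.one_hot(soft.argmax(-1), num_classes=5).float()
# forward value is the hard one-hot; the gradient flows through `soft`
straight_through = hard + soft - soft.detach()
loss = (straight_through * torch.arange(5.0)).sum()
loss.backward()  # logits.grad is populated via the soft path
```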

I’m assuming y’all want to sample at prediction time, after you’ve already trained with a relaxed distribution?

You can’t sample indices from a relaxed categorical distribution because the relaxation happens in a bigger space: the space of vectors rather than the set of integers. Two things you can do are:

- Define a relaxed ordinal distribution that samples e.g. from a Beta instead of a Binomial. This is possible because there is a natural embedding of the index space (e.g. {0,1,2,3}) into the relaxed space (e.g. [0,3]). No such embedding exists for non-ordered categorical variables.
- If your distribution doesn’t have ordinal structure, you can simply use a `Categorical` distribution for prediction. Copy the `.probs` or `.logits` over and create a new distribution.
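Concretely, copying the parameters over might look like this (a sketch with placeholder logits in place of your trained ones):

```python
import torch
from torch.distributions import Categorical, RelaxedOneHotCategorical

trained = RelaxedOneHotCategorical(torch.tensor(1.0), logits=torch.randn(32))
# reuse the learned logits in an ordinary Categorical for prediction
predictive = Categorical(logits=trained.logits)
index = predictive.sample()  # an integer index, not a simplex vector
```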

@fritzo Thanks for the reply. I understand your point.

I was trying to sample during training. I thought a relaxed distribution would be suitable for my case in order to avoid the cost of enumeration, because I aim to predict which category (`m` = 90) each data point (`n` = 60) belongs to, and `m` and `n` are high. Besides, I have seen in my experiments that using my custom `RelaxedCategoricalStraightThrough` distribution worked better than using a `Categorical` distribution (without enumeration). What could be the reason behind this?

If one should just use a `Categorical` distribution anyway, why create a new distribution at all?

Sorry, I don’t fully understand your `RelaxedCategoricalStraightThrough`. The shapes don’t quite make sense to me, e.g. `.support = simplex` but your `hard_value` looks like it has a different shape. Similarly, I don’t understand how `QuantizeCategorical.backward = lambda ctx, grad: grad` can get the shapes right. But I’m not the best person to ask about relaxed categorical distributions.

I was thinking you could use a relaxed distribution for training, then a regular `Categorical` distribution for prediction.