What does logits= mean in sample('obs', dist.Categorical(logits= ), obs=)?


What does logits= mean in pyro.sample('obs', dist.Categorical(logits= ), obs=)? Is it the input to the softmax function?

For the following example, should I use logits= or not?

The network:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NN(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        output = self.fc1(x)
        output = F.relu(output)
        # raw, unnormalized class scores (no softmax applied here)
        output = self.out(output)
        return output

net = NN(28*28, 1024, 10)

The model():

import pyro
from pyro.distributions import Normal, Categorical
from torch.nn.functional import log_softmax

def model(x_data, y_data):
    # standard normal priors over every weight and bias of the network
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))

    outw_prior = Normal(loc=torch.zeros_like(net.out.weight), scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias), scale=torch.ones_like(net.out.bias))

    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'out.weight': outw_prior, 'out.bias': outb_prior}
    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # sample a network (which also samples its weights and biases)
    lifted_reg_model = lifted_module()

    # normalized log-probabilities over the classes (last dimension)
    lhat = log_softmax(lifted_reg_model(x_data), dim=-1)

    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

In the above, by the definition of logits (the input to the softmax function), lhat is the output of softmax and thus not logits. If so, what keyword should I use in place of logits?

If I change the network to:

    class NN(nn.Module):
        def forward(self, x):
            output = self.out(output)
            output = log_softmax(output, dim=-1)
            return output

and define lhat = lifted_reg_model(x_data) instead in model(), then what should I use in Categorical()? Thanks.


Yes, it is the input into the softmax function. The softmax function converts a vector of real numbers into probabilities. For your purposes, you can either define the probs argument or the logits argument to parameterize the Categorical distribution. That is, either vector will define how frequently you tend to sample a particular category.

First, probs should be self-explanatory: when sampled from, a Categorical distribution returns the first category with probability p_1, i.e. the first entry in the probs vector, and so on for the other entries. All elements in this vector are in [0, 1] and the vector sums to 1. (Strictly speaking, all you need is for this vector to be non-negative, and Torch normalizes it for you.)

logits is a little less clear. You pass in a vector of arbitrary real numbers (each element can be any finite value, positive or negative). Then, Torch converts this into a vector of probabilities via a few operations:

  1. logits = logits - logits.logsumexp(dim=-1, keepdim=True)
  2. probs_from_logits = torch.nn.functional.softmax(logits, dim=-1)
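
The two steps above can be checked against what torch.distributions.Categorical exposes through its logits and probs attributes. This is a minimal sketch; the tensor raw and its values are made up for illustration, and Pyro's Categorical is assumed to behave the same way because it wraps the Torch distribution.

```python
import torch
from torch.distributions import Categorical

# Arbitrary unnormalized scores (hypothetical values, not from the thread).
raw = torch.tensor([2.0, -1.0, 0.5])
d = Categorical(logits=raw)

# Step 1: subtract logsumexp so the entries become normalized log-probabilities.
normalized = raw - raw.logsumexp(dim=-1, keepdim=True)

# Step 2: softmax turns the logits into probabilities that sum to 1.
probs = torch.softmax(raw, dim=-1)

print(torch.allclose(d.logits, normalized))
print(torch.allclose(d.probs, probs))
```

Note that softmax gives the same result whether it is applied to raw or to normalized, since the two differ only by a constant.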

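How do we interpret this in a probability sense? We can interpret the logits vector as unnormalized log-probabilities: each entry equals log(p_i) plus a constant that is shared by all entries. (The log-odds reading, log(p / (1-p)), applies to the two-class case with a sigmoid rather than to a softmax over many classes.) When you apply the softmax function, the shared constant cancels and you recover p exactly.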
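
A small worked check of the claim that softmax recovers p, written in plain Python so it does not depend on Torch. The probability vector p and the shift of 7.0 are arbitrary illustrative choices; the point of the sketch is that adding the same constant to every logit leaves the recovered probabilities unchanged.

```python
import math

def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

p = [0.2, 0.3, 0.5]                       # illustrative probabilities
logits_a = [math.log(pi) for pi in p]     # log-probabilities
logits_b = [l + 7.0 for l in logits_a]    # same logits shifted by a constant

recovered_a = softmax(logits_a)
recovered_b = softmax(logits_b)
print(recovered_a)
print(recovered_b)
```

Both printed vectors match p up to floating-point rounding.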

Why would we use logits over probs? Either way, they get converted into a vector of probabilities. My guess is that logits is useful so that you can remain in the real space for the purposes of optimization stability. Having a network output logits may lead to more useful gradients and numerical stability.

Edit: to answer some of your questions:

  1. log_softmax is the logarithm of the softmax function, so it is exactly the same as logits - logits.logsumexp(), the normalization Torch applies before the softmax. lhat is thus the log of the softmax output (normalized log-probabilities), not the softmax output itself, and it is still valid to pass as logits. Given Torch already normalizes with .logsumexp(), that step is redundant, though harmless.
  2. You should be fine passing any real vector as logits, with or without log_softmax(), so logits= is the right keyword in both versions of your network. If you applied softmax (without the log) you would pass the result as probs= instead. The neural network should learn to output well-behaved logits vectors.
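
To make the two points above concrete, here is a hedged sketch comparing the three ways of parameterizing the likelihood. The tensors scores and labels are random stand-ins for lifted_reg_model(x_data) and y_data, not values from the question, and torch.distributions.Categorical is used in place of Pyro's wrapper.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
scores = torch.randn(4, 10)          # stand-in for network output: 4 items, 10 classes
labels = torch.tensor([3, 0, 7, 1])  # stand-in for observed class labels

# Raw network output passed as logits.
lp_raw = Categorical(logits=scores).log_prob(labels)

# log_softmax output passed as logits (already normalized, so Torch's step is a no-op).
lp_norm = Categorical(logits=F.log_softmax(scores, dim=-1)).log_prob(labels)

# softmax output passed as probs.
lp_probs = Categorical(probs=F.softmax(scores, dim=-1)).log_prob(labels)

print(torch.allclose(lp_raw, lp_norm, atol=1e-5))
print(torch.allclose(lp_raw, lp_probs, atol=1e-5))
```

All three give the same log-likelihood for the observed labels up to floating-point error, so the choice between them is mainly about numerical convenience.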


Great answer. Thank you.