What does `logits=` mean in `pyro.sample('obs', dist.Categorical(logits=...), obs=...)`? Is it the input to the softmax function?
For the following example, should I use logits= or not?
The network:
class NN(nn.Module):
    """Feed-forward network with a single ReLU hidden layer.

    ``forward`` returns the raw output of the final linear layer; no
    softmax or log-softmax is applied inside the network.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two linear layers.

        Args:
            input_size: Number of features in each input row.
            hidden_size: Width of the hidden layer.
            output_size: Number of values produced per input row.
        """
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply ``fc1``, ReLU, then ``out`` to ``x`` and return the result."""
        hidden = F.relu(self.fc1(x))
        return self.out(hidden)


# 28*28 = 784 input features, 1024 hidden units, 10 outputs.
net = NN(28*28, 1024, 10)
The model():
def model(x_data, y_data):
    """Pyro model: Normal(0, 1) priors on ``net``'s parameters, categorical likelihood.

    Args:
        x_data: Inputs passed to the sampled network.
        y_data: Observed values for the "obs" sample site.
    """

    def unit_normal_like(param):
        # Normal(0, 1) prior with the same shape as ``param``.
        return Normal(loc=torch.zeros_like(param), scale=torch.ones_like(param))

    priors = {
        'fc1.weight': unit_normal_like(net.fc1.weight),
        'fc1.bias': unit_normal_like(net.fc1.bias),
        'out.weight': unit_normal_like(net.out.weight),
        'out.bias': unit_normal_like(net.out.bias),
    }
    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # sample a concrete network (this draws the weights and biases)
    lifted_reg_model = lifted_module()
    # NOTE(review): assuming Categorical follows torch.distributions semantics,
    # `logits=` takes log-probabilities that need not be normalized. The output
    # of a log-softmax is the already-normalized special case, so passing it as
    # `logits=` is valid (the raw network output would work as well) -- confirm.
    lhat = log_softmax(lifted_reg_model(x_data))
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)
In the above, going by the definition of logits (the input to the softmax function), lhat is the output of log_softmax and thus not logits. If so, which keyword should I use in place of logits?
If I change the network to:
class NN(nn.Module):
    ...
    def forward(self, x):
        ...
        output = self.out(output)
        # apply log_softmax to the final layer's output
        output = log_softmax(output)
        return output
and instead define `lhat = lifted_reg_model(x_data)` in model(), then what should I pass to Categorical()? Thanks.