That helps. But what I saw was that the performance was worse than with the manual guide: the manual guide gives 90% accuracy with an ELBO loss of 86, while the auto guide gives only 78% with an ELBO loss of 93. What could cause this difference in performance? Since I am sampling the random weights from `latent_model`, which is part of the model, not of the guide, does that mean the weights are not drawn from the optimized variational posterior q(w)?
I used the suggested two-part approach when defining the model. Prediction (where `trace=` is used instead of `guide_trace=`, which caused an error) is done as:
lifted_reg_model = poutine.replay(latent_model, trace=tr)()
yhats.append(log_softmax(lifted_reg_model(xdata)))
Below is the code:
def latent_model():
    """Lift `net` into a Bayesian network: place standard-normal priors
    over every weight and bias tensor and return one sampled instance."""
    # One N(0, 1) prior per parameter tensor, matching each tensor's shape.
    w1_prior = Normal(loc=torch.zeros_like(net.fc1.weight),
                      scale=torch.ones_like(net.fc1.weight))
    b1_prior = Normal(loc=torch.zeros_like(net.fc1.bias),
                      scale=torch.ones_like(net.fc1.bias))
    w_out_prior = Normal(loc=torch.zeros_like(net.out.weight),
                         scale=torch.ones_like(net.out.weight))
    b_out_prior = Normal(loc=torch.zeros_like(net.out.bias),
                         scale=torch.ones_like(net.out.bias))
    priors = {
        'fc1.weight': w1_prior,
        'fc1.bias': b1_prior,
        'out.weight': w_out_prior,
        'out.bias': b_out_prior,
    }
    # Turn the module's parameters into random variables drawn from the priors.
    bayesian_net = pyro.random_module("module", net, priors)
    return bayesian_net()
def model(x_data, y_data):
    """Generative model: sample a network (and hence all weights/biases)
    from the priors, then condition a Categorical likelihood on the labels."""
    sampled_net = latent_model()  # sampling the net also samples w and b
    logits = log_softmax(sampled_net(x_data))
    pyro.sample("obs", Categorical(logits=logits), obs=y_data)
# Optimizer over the variational parameters of the guide.
optim = Adam({"lr": 0.01})
# Stochastic variational inference: fits guide's q(w) to the model's posterior.
svi = SVI(model, guide, optim, loss=Trace_ELBO())
# Number of posterior weight samples averaged at prediction time.
num_samples = 10
def predict(xdata):
    """Posterior-predictive classification.

    Averages log-softmax outputs over `num_samples` networks whose weights
    are sampled from the variational posterior, then returns the argmax
    class index per row as a numpy array.
    """
    yhats = []
    for _ in range(num_samples):
        # BUG FIX: draw a FRESH guide trace for every Monte Carlo draw.
        # The original traced the guide once outside the loop, so all
        # num_samples replays reused identical weights and the average
        # over "samples" was a no-op — defeating the posterior averaging
        # and degrading accuracy relative to per-draw sampling.
        tr = poutine.trace(guide).get_trace(xdata, None)
        # Replay latent_model against this trace so its sample sites take
        # the values drawn from the optimized posterior q(w), not the prior.
        sampled_net = poutine.replay(latent_model, trace=tr)()
        yhats.append(log_softmax(sampled_net(xdata)))
    mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.detach().numpy(), axis=1)
# Evaluate posterior-predictive accuracy over the whole test set.
print('Prediction when network is forced to predict')
correct, total = 0, 0
for images, labels in test_loader:
    # Flatten the 28x28 images into vectors before prediction.
    preds = predict(images.view(-1, 28 * 28))
    total += labels.size(0)
    correct += (preds == labels.cpu().numpy()).sum().item()
print("accuracy: %d %%" % (100 * correct / total))