Manually written guide vs auto guide and TypeError: 'dict' object is not callable

That helps. But what I see is that performance with the auto guide is worse than with the manual guide: the manual guide reaches an accuracy of 90% with an ELBO loss of 86, while the auto guide reaches only 78% with an ELBO loss of 93. What could cause the difference in performance? Since I am sampling the random weights from latent_model, which is part of the model rather than the guide, does that mean the weights aren't drawn from the optimized variational posterior q(w)?
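
As a sanity check on that last point, here is a minimal toy sketch of what poutine.replay does (my own illustration, not from the original code; the names toy_model, toy_guide and w_loc are hypothetical): when a model is replayed against a guide trace, any sample site whose name appears in that trace takes the value recorded there, so the lifted weights in latent_model should come from the guide's approximate posterior q(w) as long as the site names match.

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    def toy_model():
        # prior: without replay this would draw w ~ Normal(0, 1)
        return pyro.sample("w", dist.Normal(0., 1.))

    def toy_guide():
        # variational posterior q(w) with a learnable location
        loc = pyro.param("w_loc", torch.tensor(2.0))
        return pyro.sample("w", dist.Normal(loc, 0.1))

    # record one draw from the guide
    guide_tr = poutine.trace(toy_guide).get_trace()

    # replaying the model against that trace reuses the guide's value at site "w",
    # so the Normal(0, 1) prior in toy_model is never actually sampled from
    w_replayed = poutine.replay(toy_model, trace=guide_tr)()
    assert torch.equal(w_replayed, guide_tr.nodes["w"]["value"])

    # the trace's site names can be inspected to confirm they match the model's
    print(list(guide_tr.nodes.keys()))

The predict function below relies on the same mechanism, with the trace coming from the trained guide instead of this toy one.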

I used the suggested two-part approach to define the model. The prediction is done as follows (using trace= instead of guide_trace=, which raised an error):

        lifted_reg_model = poutine.replay(latent_model, trace=tr)()
        yhats.append(log_softmax(lifted_reg_model(xdata)))

Below is the code:

import numpy as np
import torch
import pyro
from pyro import poutine
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# `net` (the torch.nn.Module with layers fc1 and out), `log_softmax` and
# `test_loader` are assumed to be defined elsewhere (not shown in the post)

def latent_model():
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight), scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias), scale=torch.ones_like(net.out.bias))
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'out.weight': outw_prior, 'out.bias': outb_prior}
    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()

def model(x_data, y_data):
    # sample a regressor (which also samples w and b)
    lifted_reg_model = latent_model()
    lhat = log_softmax(lifted_reg_model(x_data))
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data) 

optim = Adam({"lr": 0.01})
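
# `guide` is not defined in this snippet; from the discussion it is assumed to be an
# autoguide built on `model`, for example (hypothetical lines, not from the original
# post; the import path is pyro.infer.autoguide in newer Pyro releases):
#     from pyro.contrib.autoguide import AutoDiagonalNormal
#     guide = AutoDiagonalNormal(model)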
svi = SVI(model, guide, optim, loss=Trace_ELBO())

num_samples = 10

def predict(xdata):
    # draw a single trace from the guide; its sample sites record the weight values
    tr = poutine.trace(guide).get_trace(xdata, None)
    yhats = []
    for i in range(num_samples):
        # replay latent_model against the guide trace so the lifted weights
        # take the values recorded in tr
        lifted_reg_model = poutine.replay(latent_model, trace=tr)()
        yhats.append(log_softmax(lifted_reg_model(xdata)))
    mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.detach().numpy(), axis=1)

print('Prediction when network is forced to predict')
correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images.view(-1,28*28))
    total += labels.size(0)
    correct += (predicted == labels.cpu().numpy()).sum().item()
    
print("accuracy: %d %%" % (100 * correct / total))