I replaced a manually written guide function (which worked fine) with an auto guide, AutoDiagonalNormal. After that I got the error
*** TypeError: 'dict' object is not callable
and was not able to figure out why.
The whole code is here. At a high level, the model is:
```python
def model(x_data, y_data):
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    ...
```
Manually defined guide:
```python
def guide(x_data, y_data):
    # Variational distribution over the first-layer weights
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    ...
```
Auto guide:
```python
guide = AutoDiagonalNormal(model)

def predict(x):
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    pdb.set_trace()
    yhats = [model(x).data for model in sampled_models]  # <-- error occurs here with the auto guide
    mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.numpy(), axis=1)

correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images.view(-1, 28*28))
    total += labels.size(0)
```
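A dependency-free mimic of the failing pattern may help pin down the mechanism. It assumes that the hand-written guide (whose ending is elided above) returns a sampled network, i.e. something callable, since `model(x)` worked with it; and that calling an AutoDiagonalNormal instance returns a dict mapping sample-site names to sampled values, which is how Pyro's autoguide documentation describes the return value. Both functions below are hypothetical stand-ins, not Pyro code, and the dict key is illustrative only.

```python
# Hypothetical stand-ins for the two guides; neither is real Pyro code.

def manual_style_guide(x_data, y_data):
    """Mimics a hand-written guide that returns a sampled network (a callable)."""
    weight = 0.5  # pretend this value was drawn from the variational distribution

    def sampled_net(x):
        return [weight * xi for xi in x]

    return sampled_net


def auto_style_guide(x_data, y_data):
    """Mimics an autoguide call, which returns a dict of sampled latent values."""
    return {"fc1.weight": 0.5}  # illustrative site name


def predict_with(guide, x, num_samples=3):
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    # Same shape as the failing line: inside the comprehension, `model` is
    # each sampled element (not the model function) and it is called directly.
    return [model(x) for model in sampled_models]


print(predict_with(manual_style_guide, [1.0, 2.0]))  # [[0.5, 1.0], [0.5, 1.0], [0.5, 1.0]]

try:
    predict_with(auto_style_guide, [1.0, 2.0])
except TypeError as err:
    print(err)  # 'dict' object is not callable
```

The same comprehension works or fails depending only on whether the guide hands back something callable.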
Why does replacing a manually defined guide with an auto guide cause this error? The whole error message is:
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-15-67749cc8efa4> in <module>
12 for j, data in enumerate(test_loader):
13 images, labels = data
---> 14 predicted = predict(images.view(-1,28*28))
15 total += labels.size(0)
16 #correct += (predicted == labels).sum().item() # corrected, 12/24/2018
<ipython-input-15-67749cc8efa4> in predict(x)
3 sampled_models = [guide(None, None) for _ in range(num_samples)]
4 pdb.set_trace()
----> 5 yhats = [model(x).data for model in sampled_models]
6 mean = torch.mean(torch.stack(yhats), 0)
7 return np.argmax(mean.numpy(), axis=1)
<ipython-input-15-67749cc8efa4> in <listcomp>(.0)
3 sampled_models = [guide(None, None) for _ in range(num_samples)]
4 pdb.set_trace()
----> 5 yhats = [model(x).data for model in sampled_models]
6 mean = torch.mean(torch.stack(yhats), 0)
7 return np.argmax(mean.numpy(), axis=1)
TypeError: 'dict' object is not callable
```
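If the auto guide does return a dict of samples, the prediction loop has to use those values to parameterize the network rather than call the sample itself. The pure-Python sketch below shows that idea with a hypothetical `forward(params, x)` and a stand-in guide; in Pyro the comparable step would presumably be running the model with its latent sites fixed to the sampled values (for example via `pyro.poutine.condition`), which has not been checked against the Pyro version used here.

```python
import random


def auto_style_guide(x_data, y_data):
    """Hypothetical stand-in for an autoguide call: returns a dict of sampled values."""
    return {"weight": random.gauss(1.0, 0.1), "bias": random.gauss(0.0, 0.1)}


def forward(params, x):
    """Hypothetical forward pass of a tiny linear 'network' driven by sampled params."""
    return [params["weight"] * xi + params["bias"] for xi in x]


def predict(x, num_samples=100):
    # Draw posterior samples as dicts, then feed each dict into the forward pass
    # instead of calling the dict.
    sampled_params = [auto_style_guide(None, None) for _ in range(num_samples)]
    yhats = [forward(params, x) for params in sampled_params]
    # Average the predictions over samples, one output position at a time.
    return [sum(col) / len(col) for col in zip(*yhats)]


random.seed(0)
print(predict([1.0, 2.0]))  # close to [1.0, 2.0]
```

Keeping the sampled values separate from the forward pass is what lets one network definition be reused for every posterior sample.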
Thanks.