I get errors when trying to run multiple chains in parallel using HMC, and they persist on multiple computers and across reboots. The error is:
RuntimeError: received 0 items of ancdata.
See the paste on Pastebin.com (titled "RuntimeError Traceback (most recent call last)<ip…") for the full error.
Any ideas as to what might be wrong? I do inference with
# Inference setup: a NUTS kernel over `model`, then MCMC with 5 chains
# requested via num_chains.
# NOTE(review): "received 0 items of ancdata" is presumably raised by
# torch.multiprocessing tensor sharing between the chain processes rather
# than by the model itself; consider checking the open-file limit or the
# sharing strategy -- TODO confirm against the full traceback.
nuts_kernel = NUTS(model, max_tree_depth=5)
posterior = MCMC(
    nuts_kernel,
    num_samples=500,
    warmup_steps=2000,
    num_chains=5,
).run(x, y)
and my model looks like
def model(x, y):
    """Regression model with Normal priors over the parameters of ``net``.

    Places an independent zero-mean Normal prior on the weight and bias of
    each of the layers ``fc1_mean``, ``fc2_mean``, ``fc1_var`` and
    ``fc2_var`` of the file-level ``net``, lifts ``net`` with those priors,
    draws one network, and scores ``y`` under a Normal likelihood whose mean
    and log-variance are the two outputs of the sampled network.

    Args:
        x: Inputs passed to the sampled network.
        y: Observed values for the ``"obs"`` sample site.

    Returns:
        The result of the ``pyro.sample`` call for the ``"obs"`` site.
    """
    prior_std_mean = 1.0
    prior_std_var = 0.5

    # (layer name, prior std) pairs, in the same order the prior dict was
    # originally spelled out by hand.
    layer_prior_stds = (
        ("fc1_mean", prior_std_mean),
        ("fc2_mean", prior_std_mean),
        ("fc1_var", prior_std_var),
        ("fc2_var", prior_std_var),
    )

    # Build one Normal(0, std) prior per parameter, shaped like the
    # parameter it replaces; keys are "<layer>.<weight|bias>".
    priors = {}
    for layer_name, prior_std in layer_prior_stds:
        layer = getattr(net, layer_name)
        for param_name in ("weight", "bias"):
            param = getattr(layer, param_name)
            priors[layer_name + "." + param_name] = Normal(
                loc=torch.zeros_like(param),
                scale=prior_std * torch.ones_like(param),
            )

    lifted_module = pyro.random_module("module", net, priors)
    sampled_reg_model = lifted_module()
    mu, log_sigma_2 = sampled_reg_model(x)

    # sqrt(exp(v)) == exp(v / 2); halving before exponentiating avoids the
    # intermediate exp(v) overflowing for large log-variances.
    sigma = torch.exp(0.5 * log_sigma_2)
    return pyro.sample("obs", pyro.distributions.Normal(mu, sigma), obs=y)
where net is just a regular neural network made of fully connected layers with ReLU activations.
Thank you!