I came across the same problem in a much simpler context. I wanted to modify your intro to Bayesian linear regression example so that the posterior mean of the predictive distribution is stored as a deterministic variable, instead of relying on the _RETURN magic incantation. But I get a mysterious extra dim of 1 added. (This also occurs in other uses of Predictive where I want to return stochastic parameters of the model (e.g., linear.weight), so it is not unique to deterministic sites.)
Specifically, here is the model (the only change is marked ## NEW):
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        mu = pyro.deterministic("mu", mean)  ## NEW
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean
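For reproducibility, any x_data and y_data of the right shapes will do; here is a hypothetical synthetic dataset with N = 170 rows and 3 features, matching the shapes printed below (the names true_w and the seed are my own, not from the tutorial):

```python
import torch

# Hypothetical synthetic regression data: 170 rows, 3 features.
torch.manual_seed(0)
x_data = torch.randn(170, 3)
true_w = torch.tensor([1.5, -0.8, 0.3])
y_data = x_data @ true_w + 0.1 * torch.randn(170)
```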
I run MCMC inference to get the parameter posterior:
from pyro.infer import MCMC, NUTS

pyro.set_rng_seed(1)
model = BayesianRegression(3, 1)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(x_data, y_data)
Finally, I compute the posterior predictive:
from pyro.infer import Predictive

predictive = Predictive(model, mcmc.get_samples(), return_sites=("obs", "mu", "_RETURN"))
hmc_samples_pred = predictive(x_data)
print(hmc_samples_pred.keys())
print(hmc_samples_pred['obs'].shape)
print(hmc_samples_pred['mu'].shape)
print(hmc_samples_pred['_RETURN'].shape)
This yields the following (note that the shape of mu is (S, 1, N) instead of (S, N), for reasons that are not clear to me):
dict_keys(['obs', 'mu', '_RETURN'])
torch.Size([1000, 170])
torch.Size([1000, 1, 170]) ## very weird
torch.Size([1000, 170])
I see some shenanigans about adding an extra dim on line 72 of pyro/infer/predictive.py, but I don't understand it…
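In the meantime I am just dropping the singleton dimension by hand. A minimal sketch of that workaround, assuming the deterministic site always comes back as (S, 1, N) with the extra dim at position 1 (here I fake the Predictive output with zeros, since the point is only the shape manipulation):

```python
import torch

# Stand-in for hmc_samples_pred['mu'], which Predictive returns
# with an unexpected singleton dim: (S, 1, N) = (1000, 1, 170).
mu = torch.zeros(1000, 1, 170)

# Drop the extra singleton dimension to recover (S, N).
# squeeze(1) only removes the dim if it has size 1, so this is safe
# even if a future fix makes the site come back as (S, N) already... no:
# in that case dim 1 is N, so guard on the number of dims instead.
mu_squeezed = mu.squeeze(1) if mu.dim() == 3 else mu
print(mu_squeezed.shape)  # torch.Size([1000, 170])
```

This obviously just papers over the symptom rather than explaining where the extra dim comes from.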