I think he means that a dataloader does the subsampling/batching before the data is fed into the model. That can be seen as better because the linear layers in the network then only operate on the batch. In contrast, if you subsample in the plate, your linear layers have already run over the full dataset, and the plate only subsamples afterwards, at the likelihood-scoring stage.
Although the dataloader route can be more computationally efficient, the ELBO will then treat each batch as if it were the full dataset, so the loss isn’t properly scaled (i.e., your priors get too much weight and your data too little); hence you need to poutine.scale the likelihood up (by full dataset size / batch size) if you go this route.
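For concreteness, here’s a minimal sketch of that dataloader route. This is my own illustration, not the poster’s code: I use a toy linear model in place of the full network, and `batch_x`, `batch_y`, and `full_size` are hypothetical names (a DataLoader would supply the batches).

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def batched_model(batch_x, batch_y=None, full_size=1):
    # Toy prior over a single weight vector (stand-in for the full network).
    w = pyro.sample('w', dist.Normal(0., 1.).expand([batch_x.shape[1]]).to_event(1))
    sigma = pyro.sample('sigma', dist.Uniform(0., 10.))
    mean = batch_x @ w
    # Scale the likelihood up so the ELBO treats this minibatch as standing in
    # for all `full_size` data points (otherwise the priors dominate).
    with poutine.scale(scale=full_size / batch_x.shape[0]):
        with pyro.plate('data', batch_x.shape[0]):
            pyro.sample('obs', dist.Normal(mean, sigma), obs=batch_y)
    return mean

Each svi.step(batch_x, batch_y, full_size=len(dataset)) on a DataLoader batch then contributes a correctly weighted likelihood term.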
I do happen to have a working Bayesian NN below that uses subsampling. I don’t know why the above poster’s version doesn’t work, but the version below does. That said, it may need additional tuning for optimal performance.

I’ve noticed that when people talk about Bayesian deep learning, they don’t necessarily mean just putting priors on the parameters of a neural network (the simplest approach, I think). Although that can give good model performance in some areas, one shouldn’t assume it automatically means you’re correctly exploring/capturing the full posterior of a giant neural network (admittedly, this is an active research area). Deep kernel learning, SWAG/MultiSWAG, etc. also fall under Bayesian deep learning and can get really good performance without much tuning.
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule, PyroSample

class Bayesian_Network(PyroModule):
    def __init__(self, in_size, out_size):
        super().__init__()
        # Neural network layers (converts nn.Modules to PyroModules).
        self.fc1 = PyroModule[nn.Linear](in_size, 100)
        self.fc2 = PyroModule[nn.Linear](100, 150)
        self.fc3 = PyroModule[nn.Linear](150, 100)
        self.fc4 = PyroModule[nn.Linear](100, out_size)
        # Priors on the parameters (replaces nn.Parameters with PyroSamples).
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([100, in_size]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 10.).expand([100]).to_event(1))
        self.fc2.weight = PyroSample(dist.Normal(0., 1.).expand([150, 100]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0., 10.).expand([150]).to_event(1))
        self.fc3.weight = PyroSample(dist.Normal(0., 1.).expand([100, 150]).to_event(2))
        self.fc3.bias = PyroSample(dist.Normal(0., 10.).expand([100]).to_event(1))
        self.fc4.weight = PyroSample(dist.Normal(0., 1.).expand([out_size, 100]).to_event(2))
        self.fc4.bias = PyroSample(dist.Normal(0., 10.).expand([out_size]).to_event(1))

    def forward(self, x, y=None):
        # Neural network computation.
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        mean = self.fc4(x).squeeze(-1)  # squeeze(-1) makes `mean` 1D (instead of 2D with rightmost dim of size 1)
        # Prior on the observation noise.
        sigma = pyro.sample('sigma', dist.Uniform(0., 10.))
        # Likelihood: the plate subsamples the data and automatically rescales
        # the log-likelihood so the ELBO stays correctly weighted.
        with pyro.plate('data', x.shape[0], subsample_size=1000) as ind:
            obs = y.index_select(0, ind) if y is not None else None  # guard so prediction (y=None) works
            pyro.sample('obs', dist.Normal(mean.index_select(0, ind), sigma), obs=obs)
        return mean
# Train model.
# `x` and `y` are your training tensors; synthetic placeholders for illustration:
x = torch.randn(5000, 5)
y = torch.randn(5000)

pyro.clear_param_store()
bayesian_network = Bayesian_Network(5, 1)
guide = AutoNormal(bayesian_network)
optimizer = pyro.optim.Adam({'lr': 0.01})
svi = SVI(bayesian_network, guide, optimizer, Trace_ELBO())
for step in range(501):
    loss = svi.step(x, y) / y.numel()
    if step % 100 == 0:
        print(f"Step {step}, loss = {loss}")