Mini-batch training of SVI models

Are the two ways of training SVI models equivalent?

The first one, as used in the Pyro tutorials:

# mini-batch logic defined in the model function with pyro.plate
svi = pyro.infer.SVI(model, guide, optim, loss = pyro.infer.Trace_ELBO())
for i in range(num_iters):
    svi.step(X, y)

The second one, which is more common in PyTorch:

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(X, y)
loader  = DataLoader(dataset, batch_size = batch_size, shuffle = True)
# No mini-batch logic in model
svi     = pyro.infer.SVI(model, guide, optim, loss = pyro.infer.Trace_ELBO())
for epoch in range(num_epoch):
    for bx, by in loader:
        svi.step(bx, by)

How do you implement the in-model version?

Like this:

def model(X, y):
    num_x  = X.shape[0]
    # place a standard Normal prior on every parameter of the network
    priors = dict()
    for n, p in nn.named_parameters():  # nn: an nn.Module neural network
        priors[n] = pyro.distributions.Normal(loc=torch.zeros_like(p), scale=torch.ones_like(p)).to_event(1)

    lifted_module    = pyro.random_module("module", nn, priors)
    lifted_reg_model = lifted_module()
    # subsample inside the model; batch_size is defined outside this function
    with pyro.plate("map", len(X), subsample_size=min(num_x, batch_size)) as ind:
        pred = lifted_reg_model(X[ind]).squeeze(-1)
        pyro.sample("obs", pyro.distributions.Normal(pred, 1e-2), obs=y[ind])

As I understand it, subsample_size is ultimately used to scale the likelihood terms, i.e., the final ELBO value is effectively the mini-batch value multiplied by len(X) / subsample_size. If this is not right, please correct me.
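One way to check this (a minimal sketch, assuming the model, X, y, and batch_size defined above) is to trace the model and inspect the scale that the plate attaches to the "obs" site:

import pyro.poutine as poutine

# trace one execution of the model; pyro.plate with subsample_size attaches
# a scale factor to the sample sites inside it
tr = poutine.trace(model).get_trace(X, y)
print(tr.nodes["obs"]["scale"])  # expect len(X) / subsample_size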

Although this is a kind of stochastic optimization, it requires loading all of the samples into memory. I guess that when you use a GPU, you might even need to load the whole dataset onto the GPU? If that is true, it would cost too much GPU memory.

On the other hand, if you do stochastic optimization the PyTorch way, I think you can reduce GPU memory usage, since you only need to load a subset of the samples onto the GPU at each iteration. Also, the model will not scale the ELBO, since it only sees the data you feed it.

The above is my understanding, and I hope someone can correct me if it is wrong.

The subsample_size kwarg in pyro.plate scales the likelihood terms in your ELBO accordingly. If you want to mini-batch manually, you can use poutine.scale.
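For example, a manual mini-batch version of the model above might look roughly like this (just a sketch; scaled_model and the argument N for the total dataset size are illustrative names, and nn is the same network as before):

import pyro
import pyro.poutine as poutine

def scaled_model(bx, by, N):
    priors = dict()
    for n, p in nn.named_parameters():
        priors[n] = pyro.distributions.Normal(loc=torch.zeros_like(p), scale=torch.ones_like(p)).to_event(1)
    lifted_reg_model = pyro.random_module("module", nn, priors)()
    # scale only the likelihood so the mini-batch ELBO is an unbiased
    # estimate of the full-data ELBO
    with poutine.scale(scale=N / bx.shape[0]):
        pred = lifted_reg_model(bx).squeeze(-1)
        pyro.sample("obs", pyro.distributions.Normal(pred, 1e-2), obs=by)

# used with the DataLoader loop from above (the guide must accept the same arguments):
# for bx, by in loader:
#     svi.step(bx, by, N=len(dataset))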


In the Variational Autoencoder tutorial, the training code is like this:


def train(svi, train_loader, use_cuda=False):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch x returned
    # by the data loader
    for x, _ in train_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)

    # return epoch loss
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train

It seems that manual mini-batching is used here through the train_loader, so why is there no scaling of the likelihood term?

This is a special case, since there are only local random variables in the model. The subsampling scale factor scales all the variables by the same amount, so it is dropped. In a model with a mix of both global and local variables, you would need to scale the "local" likelihood (eq 10).
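For instance, something along these lines (purely illustrative names; weight and sigma are global latents, the observations are local):

import pyro
import pyro.poutine as poutine

def mixed_model(bx, by, N):
    # global latents: shared across the whole dataset, so they are not scaled
    weight = pyro.sample("weight", pyro.distributions.Normal(0., 1.))
    sigma  = pyro.sample("sigma", pyro.distributions.LogNormal(0., 1.))
    # local likelihood terms: scaled by N / batch_size so that the mini-batch
    # ELBO is an unbiased estimate of the full-data ELBO (cf. eq 10)
    with poutine.scale(scale=N / bx.shape[0]):
        with pyro.plate("data", bx.shape[0]):
            pyro.sample("obs", pyro.distributions.Normal(weight * bx, sigma), obs=by)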


I think it's also not quite the same, because in 2) you iterate over the entire dataset in each epoch (since the shuffling is per epoch), while in 1) there is no guarantee that you will do so. Right?
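One quick way to see the difference (an illustrative check, independent of the models above): pyro.plate draws a fresh random subsample each time it is entered, so within a fixed number of steps some indices may repeat while others are never visited.

import pyro

# each entry into the plate draws an independent random subsample of indices;
# there is no per-epoch guarantee of covering every data point
for step in range(3):
    with pyro.plate("demo", size=10, subsample_size=3) as ind:
        print(step, ind)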
