Subsampling vs batching

Hi Pyro experts,
I come from the general ML/DL world, but I'm a noob when it comes to Bayesian methods. I read this tutorial on scaling SVI to large datasets and had a few questions about subsampling.

  • What is subsampling used for?
  • Can subsampling be thought of as batching (as understood in frequentist ML) in the context of a large dataset?

Hi @Gandalf,

What is subsampling used for?

Subsampling in Bayesian inference is generally used either when the data is too big to fit in memory, or when the data is homogeneous enough that it is cheaper to train on small minibatches.

Can subsampling be thought of as batching (as understood in frequentist ML) in the context of a large dataset?

Yes. In fact, most of the time in big models we batch the data manually, just as one would when training a neural net, and pass each batch to Pyro via the explicit subsample argument of pyro.plate. In simple models, however, pyro.plate can do automatic random subsampling of the input data for you.

The one big difference between Bayesian and generic ML subsampling is that in Bayesian statistics the data competes with a prior, so minibatches need to be weighted carefully to compete with the prior in the right proportion. Concretely, the minibatch log-likelihood is scaled by full_size / batch_size so that it carries the same weight against the prior as the full dataset's likelihood would.

Thanks, this makes the theoretical differences very clear.

Sorry, I need some clarification on the quote above. By big models I suppose you mean large neural networks. Shouldn’t it be data size dependent?

Shouldn’t it be data size dependent?

You’re right, I was being a little sloppy there :slightly_smiling_face: But I do believe big data needs big models, so I generally think of the two as synonymous in Bayesian contexts. Anyway, what matters most is memory footprint and op count per step, and minibatching helps with both.

@fritzo I completely agree with you on that correlation! :grinning:

One last question on this. The plate API seems to offer two ways to minibatch: one using subsample_size=<minibatch size> and another using subsample=. Are they equivalent? What is subsample supposed to take?

The two methods differ as follows. (1) pyro.plate(..., subsample_size=_) is simpler and lets Pyro choose the minibatch, but assumes your entire dataset fits in memory; in this version you'll need to slice the data tensors down manually, as described in the tensor shapes tutorial. (2) pyro.plate(..., subsample=_) assumes you are already subsampling the data yourself, which is more useful if, e.g., you are using a PyTorch data loader, or you are dynamically loading or creating minibatches from a dataset that wouldn't fit in memory.
What is subsample supposed to take?

Strictly speaking, it should take a LongTensor of indices into your full dataset indicating which batch items are selected, i.e. the same kind of ind tensor that automatic subsampling returns:

import pyro
with pyro.plate("my_plate", 100, 10) as ind:
    # ind is a LongTensor of 10 random indices into range(100), the same kind of thing you'd pass as subsample
    print(ind)

But in almost all cases you can pass any tensor of the correct length, so I often simply pass the subsampled data tensor itself (rather than an index tensor into the data). Basically, len(subsample) just needs to be correct so that the scale factor is computed correctly.

Thanks @fritzo, this helps a lot!

This is a very useful post. There is a lot of confusion about batching in Pyro. @fritzo, thanks for the clarifications.

As a follow-up to my post here, would the following be the correct way to minibatch? (Just need your stamp of approval :grimacing:)

def forward(self, x, y=None):
    '''
    x, y will be PyTorch tensors BATCH_LEN long
    '''
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = self.linear(x).squeeze(-1)
    # with pyro.plate("data", size=self.DATA_LEN, subsample_size=self.BATCH_LEN):
    with pyro.plate("data", size=self.DATA_LEN, subsample=y):
        obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
    return mean