Mini Batching in Pyro

I am trying to understand how mini-batching works in Pyro. To that end, I have tried implementing mini-batching (as described here) in the linear regression model from the Regression example.

I am breaking the data into chunks of length BATCH_LEN and feeding them to the model. The model is told (via the pyro.plate statement in the forward function) that the length of the whole dataset is DATA_LEN and the length of each batch is BATCH_LEN.

My code is as follows:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features, DATA_LEN, BATCH_LEN):
        '''
        DATA_LEN: length of the complete training data
        BATCH_LEN: Batch size 
        '''
        super().__init__()
        self.BATCH_LEN = BATCH_LEN
        self.DATA_LEN = DATA_LEN
        self.linear = PyroModule[nn.Linear](in_features, out_features)       
        self.linear.weight = PyroSample(dist.Normal(0., 1.)
                                   .expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))
    def forward(self, x, y=None):
        ''' 
        x ,y  will be   BATCH_LEN long
        '''
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", size=self.DATA_LEN, subsample_size=self.BATCH_LEN):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

num_epochs = 1500
batch_len = 34
data_len = len(x_data)
number_of_batches = data_len // batch_len  # == 5

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal

model = BayesianRegression(3, 1, data_len, batch_len)
guide = AutoDiagonalNormal(model)

adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

pyro.clear_param_store()
for j in range(num_epochs):
    loss = 0
    for i in range(0, data_len, batch_len):
        x_batch = x_data[i:i + batch_len]
        y_batch = y_data[i:i + batch_len]
        loss += svi.step(x_batch, y_batch)
    if j % 100 == 0:
        print("[EPOCH LOSS %04d] loss: %.4f" % (j + 1, loss / data_len))

Just wanted to check with you:

  • Is this the correct way to feed minibatches?

  • When run, the ELBO hardly changes even after 1500 epochs. What could be the issue here?

    [EPOCH LOSS 0001] loss: 7.4139
    [EPOCH LOSS 0101] loss: 7.2962
    [EPOCH LOSS 0201] loss: 7.3909
    [EPOCH LOSS 0301] loss: 7.2785
    [EPOCH LOSS 0401] loss: 7.3260
    [EPOCH LOSS 0501] loss: 7.2425
    [EPOCH LOSS 0601] loss: 7.3000
    [EPOCH LOSS 0701] loss: 7.3091
    [EPOCH LOSS 0801] loss: 7.2648
    [EPOCH LOSS 0901] loss: 7.3907
    [EPOCH LOSS 1001] loss: 7.4073
    [EPOCH LOSS 1101] loss: 7.3119
    [EPOCH LOSS 1201] loss: 7.3070
    [EPOCH LOSS 1301] loss: 7.3720
    [EPOCH LOSS 1401] loss: 7.4121

Many Thanks in advance!

Because you put a uniform prior on sigma and use an autoguide, sigma is probably initialized to a very high value like 25. That may be causing the bad optimization you seem to be seeing. You'll probably get better results if you turn sigma into a param or use a non-uniform prior like LogNormal(0, 2).
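For concreteness, a minimal sketch of those two alternatives (the sigma line in forward would be replaced by one of these; the param name sigma_loc is just illustrative):

    from pyro.distributions import constraints

    # Option 1: keep sigma as a latent variable, but with a LogNormal(0, 2)
    # prior instead of Uniform(0, 10).
    sigma = pyro.sample("sigma", dist.LogNormal(0., 2.))

    # Option 2: make sigma a learnable point estimate via pyro.param,
    # constrained to be positive, so it is no longer handled by the autoguide.
    sigma = pyro.param("sigma_loc", torch.tensor(1.0),
                       constraint=constraints.positive)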

Thanks for the suggestions @martinjankowiak !
My original query was about the batching… But after your reply, I looked into your suggestions and now have further questions. Let me list all my queries here:

  1. My primary query - am I doing mini-batching right?
    The tutorial seems to assume that the full data is in memory and subsamples from it:

    for i in pyro.plate("data_loop", len(data), subsample_size=5):
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

Practically, the full data may not fit in memory, and we would like to feed the model batches - that is what I have done in my first post (see the batching sketch below the colab link). Does it look right to you?

  2. I tried your suggestion of LogNormal(0, 2) vs Uniform(0, 2). Both converge to around the same value, e.g. with sigma = pyro.sample("sigma", dist.Uniform(0., 2.)):
[EPOCH LOSS 0001] loss: 67.9620
...
[EPOCH LOSS 1401] loss: 7.1958

with sigma = pyro.sample("sigma", dist.LogNormal(0., 2.)):

[EPOCH LOSS 0001] loss: 49.8630
...
[EPOCH LOSS 1401] loss: 7.2040

But I observed something: let's say the data size is N and the batch size is M (in the above case N=170, M=34, N/M=5).
The loss (say at the 1401st epoch) is N/M times the loss with batch_size=N (as can be seen in the Regression example):

[iteration 1401] loss: 1.4581

1.4581 * 5 ~ 7.2

This doesn’t look right as I thought the purpose of providing len(data), subsample_size=5 in the plate statement was to adjust for this. Am I understanding this correctly?

Here is a colab with the code.
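Here is the kind of batching loop I mean for item 1, as a sketch; the torch DataLoader is just one way to produce the chunks, and for truly out-of-memory data a custom Dataset would take its place:

    from torch.utils.data import TensorDataset, DataLoader

    # Yield BATCH_LEN-sized chunks; only one batch is handed to svi.step at a time.
    loader = DataLoader(TensorDataset(x_data, y_data),
                        batch_size=batch_len, shuffle=True)

    pyro.clear_param_store()
    for j in range(num_epochs):
        epoch_loss = 0.0
        for x_batch, y_batch in loader:
            epoch_loss += svi.step(x_batch, y_batch)
        if j % 100 == 0:
            print("[EPOCH LOSS %04d] loss: %.4f" % (j + 1, epoch_loss / data_len))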

Hi @VSDV, I am not an expert, but let’s see if I can help.

    for i in range(0, data_len, batch_len):
        x_batch = x_data[i:i + batch_len]
        y_batch = y_data[i:i + batch_len]
        loss += svi.step(x_batch, y_batch)

In the part of the code above, you prepare the mini-batches manually and feed them as data to the guide and the model (via svi.step). So there is no need for the data plate to choose a random subsample. The size of the plate is always self.BATCH_LEN, and it corresponds to the size of the data that you pass to the function.

If you want to pass all your data as mini-batches, you could just remove subsample_size = self.BATCH_LEN from the plate, and possibly replace size = self.DATA_LEN by size = len(y), or whichever shape dimension of y corresponds to your batches (in case a batch is a bit shorter).
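In code, that would look roughly like this (a sketch; I use x.shape[0] so the plate size also works when y is None at prediction time):

    def forward(self, x, y=None):
        # x and y are already a mini-batch prepared outside the model.
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        # The plate size is just the size of whatever was passed in; no
        # subsampling (and no N/M rescaling) happens inside the model.
        with pyro.plate("data", size=x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean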

If you want to select a random subsample every time you call the function, you can pass x_data and y_data instead of x_batch and y_batch to svi.step. Then inside the data plate you need to subsample y (and anything computed from the full x, such as mean) as well; otherwise you have a size mismatch between the plate and the observation. Typically you would do it as shown below.

        with pyro.plate("data", size=self.DATA_LEN, subsample_size=self.BATCH_LEN) as ind:
            obs = pyro.sample("obs", dist.Normal(mean[ind], sigma), obs=y[ind])

The as ind idiom is what allows you to get the indices that were randomly selected. If y is multidimensional, you may have to select some specific dimensions so the code would look a little different, but the idea is the same.
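With that version, the training loop passes the full tensors and lets the plate do the subsampling, roughly:

    pyro.clear_param_store()
    for j in range(num_epochs):
        # The plate draws a fresh random set of BATCH_LEN indices at each step
        # and rescales the likelihood term by DATA_LEN / BATCH_LEN.
        loss = svi.step(x_data, y_data)
        if j % 100 == 0:
            print("[LOSS %04d] loss: %.4f" % (j + 1, loss / data_len))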

Many thanks for your reply @gui11aume! :slight_smile:

So what you are suggesting would look like this:

        with pyro.plate("data", size=self.BATCH_LEN):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)  # here, len(y) == BATCH_LEN

But as per my understanding of the SVI Part II tutorial, especially this equation and the subsequent discussion, won't pyro.plate need N and M as inputs to scale the log-likelihood by N/M?
(N being self.DATA_LEN and M being self.BATCH_LEN)
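For reference, the scaling I am referring to is (with $\mathcal{I}_M$ a random subset of $M$ of the $N$ data indices)

$$\sum_{n=1}^{N} \log p(y_n \mid x_n, w, b, \sigma) \;\approx\; \frac{N}{M} \sum_{n \in \mathcal{I}_M} \log p(y_n \mid x_n, w, b, \sigma),$$

and my understanding is that pyro.plate applies this N/M factor only when it is given both size and subsample_size (or explicit subsample indices).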