I am trying to understand how minibatching works in Pyro. To do so, I have tried implementing minibatching (as described here) in the linear regression model from the Regression example.
I am splitting the data into chunks of BATCH_LEN points and feeding them to the model one at a time. The model is told, via the `pyro.plate` statement in `forward`, that the whole dataset has DATA_LEN points and that each batch has BATCH_LEN points.
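For reference, this is my understanding of what the subsampled plate does, sketched in plain Python rather than Pyro. As I read the Pyro docs, when both `size` and `subsample_size` are given, the log-likelihood inside the plate is multiplied by `size / subsample_size` (here DATA_LEN / BATCH_LEN), so each minibatch yields a noisy estimate of the full-data log-likelihood. The helpers `normal_log_prob` and `scaled_batch_log_lik` are hypothetical, and the toy data (578 points in 17 batches of 34) is made up for illustration.

```python
import math
import random

def normal_log_prob(y, mean, sigma):
    """Log-density of Normal(mean, sigma) evaluated at y."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (y - mean) ** 2 / (2 * sigma ** 2)

def full_log_lik(ys, means, sigma):
    """Log-likelihood summed over every data point."""
    return sum(normal_log_prob(y, m, sigma) for y, m in zip(ys, means))

def scaled_batch_log_lik(ys, means, sigma, idx, data_len):
    """Batch log-likelihood multiplied by data_len / batch_len,
    the factor a subsampled plate is assumed to apply."""
    batch_sum = sum(normal_log_prob(ys[i], means[i], sigma) for i in idx)
    return (data_len / len(idx)) * batch_sum

# Hypothetical toy data: 578 points split into 17 contiguous batches of 34.
random.seed(0)
DATA_LEN, BATCH_LEN = 578, 34
ys = [random.gauss(0.0, 1.0) for _ in range(DATA_LEN)]
means = [0.0] * DATA_LEN

full = full_log_lik(ys, means, 1.0)
estimates = [
    scaled_batch_log_lik(ys, means, 1.0, range(i, i + BATCH_LEN), DATA_LEN)
    for i in range(0, DATA_LEN, BATCH_LEN)
]
# Each single estimate is noisy, but averaged over one full pass
# through equal-sized batches they reproduce the full-data sum.
avg = sum(estimates) / len(estimates)
```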
My code is as follows:
```python
import pyro
import pyro.optim
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule, PyroSample
from torch import nn

# x_data, y_data and data are prepared as in the Regression example.

class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features, DATA_LEN, BATCH_LEN):
        '''
        DATA_LEN: length of the complete training data
        BATCH_LEN: batch size
        '''
        super().__init__()
        self.BATCH_LEN = BATCH_LEN
        self.DATA_LEN = DATA_LEN
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.)
                                        .expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        '''
        x, y will be BATCH_LEN long
        '''
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(1)
        with pyro.plate("data", size=self.DATA_LEN, subsample_size=self.BATCH_LEN):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

num_epochs = 1500
batch_len = 34
data_len = len(x_data)
number_of_batches = data_len / batch_len  # == 17

model = BayesianRegression(3, 1, data_len, batch_len)
guide = AutoDiagonalNormal(model)
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

pyro.clear_param_store()
for j in range(num_epochs):
    loss = 0
    for i in range(0, data_len, batch_len):
        x_batch = x_data[i:(i + batch_len)]
        y_batch = y_data[i:(i + batch_len)]
        loss += svi.step(x_batch, y_batch)
    if j % 100 == 0:
        print("[EPOCH LOSS %04d] loss: %.4f" % (j + 1, loss / len(data)))
```
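One detail of the loop I am unsure about: slicing with `range(0, data_len, batch_len)` only gives equal-sized batches when `data_len` is a multiple of `batch_len`. Otherwise the last batch is shorter than the fixed `subsample_size`, which Pyro would presumably reject as a shape mismatch inside the plate. Also, as far as I can tell, with only `subsample_size` given the plate draws its own random indices, which this loop never uses, so only the rescaling takes effect. The plain-Python sketch below compares the slicing scheme with a reshuffled, fixed-size scheme; `contiguous_batches` and `shuffled_batches` are hypothetical helpers. If I read the docs correctly, the Pyro-side equivalent is to pass the indices with `pyro.plate(..., subsample=idx)` or to read them from `with pyro.plate(...) as ind:` and index the data with them.

```python
import random

def contiguous_batches(data_len, batch_len):
    """Index lists matching x_data[i:i + batch_len] for i in range(0, data_len, batch_len)."""
    return [list(range(i, min(i + batch_len, data_len)))
            for i in range(0, data_len, batch_len)]

def shuffled_batches(data_len, batch_len, seed=None):
    """Hypothetical alternative: shuffle all indices, then keep only full batches,
    so every batch has exactly batch_len indices (the ragged tail is dropped)."""
    rng = random.Random(seed)
    perm = list(range(data_len))
    rng.shuffle(perm)
    n_full = data_len // batch_len
    return [perm[k * batch_len:(k + 1) * batch_len] for k in range(n_full)]

# 578 = 17 * 34 splits evenly; 580 leaves a final batch of only 2 indices.
even = contiguous_batches(578, 34)
ragged = contiguous_batches(580, 34)
print(len(even), len(ragged), len(ragged[-1]))  # 17 18 2

# Use the same indices for inputs and targets so rows stay paired.
xs = [[float(i)] for i in range(580)]  # hypothetical stand-in for x_data
ys = [float(i) for i in range(580)]    # hypothetical stand-in for y_data
for idx in shuffled_batches(len(xs), 34, seed=1):
    x_batch = [xs[i] for i in idx]
    y_batch = [ys[i] for i in idx]
```

Reshuffling every epoch also means each batch mixes different rows over time, instead of always grouping the same contiguous rows together.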
Just wanted to check with you:

Is this the correct way to feed minibatches?

When I run it, the ELBO hardly changes over 1500 epochs (output below). What could be the issue here?
```
[EPOCH LOSS 0001] loss: 7.4139
[EPOCH LOSS 0101] loss: 7.2962
[EPOCH LOSS 0201] loss: 7.3909
[EPOCH LOSS 0301] loss: 7.2785
[EPOCH LOSS 0401] loss: 7.3260
[EPOCH LOSS 0501] loss: 7.2425
[EPOCH LOSS 0601] loss: 7.3000
[EPOCH LOSS 0701] loss: 7.3091
[EPOCH LOSS 0801] loss: 7.2648
[EPOCH LOSS 0901] loss: 7.3907
[EPOCH LOSS 1001] loss: 7.4073
[EPOCH LOSS 1101] loss: 7.3119
[EPOCH LOSS 1201] loss: 7.3070
[EPOCH LOSS 1301] loss: 7.3720
[EPOCH LOSS 1401] loss: 7.4121
```
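On how these numbers are computed: if each `svi.step` already returns an estimate of the full-data loss (because of the plate rescaling), then summing the step losses of one epoch and dividing by `len(data)` reports roughly `number_of_batches` times the per-point loss, assuming `len(data)` equals `data_len`. A small sketch of the two normalizations, with hypothetical helper names and made-up step losses:

```python
def per_point_epoch_loss(step_losses, data_len):
    """Mean of the per-step losses (each assumed to estimate the full-data loss),
    divided by data_len to give a per-data-point value."""
    return sum(step_losses) / len(step_losses) / data_len

def summed_epoch_loss(step_losses, data_len):
    """What the question's loop prints: the sum over all steps divided by data_len,
    which is len(step_losses) times larger."""
    return sum(step_losses) / data_len

# Made-up step losses for one epoch of 17 batches over 578 points.
step_losses = [1156.0] * 17
print(per_point_epoch_loss(step_losses, 578))  # 2.0
print(summed_epoch_loss(step_losses, 578))     # 34.0
```

The ratio between the two is a constant factor, so it changes only the scale of the printed numbers, not whether they trend downward.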
Many thanks in advance!