This is a very useful post. There is a lot of confusion about batching in Pyro; thanks @fritzo for the clarifications.
As a follow-up to my post here, is the following the correct way to mini-batch? (I just need your stamp of approval.)
import pyro
import pyro.distributions as dist

def forward(self, x, y=None):
    '''
    x, y will be PyTorch tensors BATCH_LEN long
    '''
    # Global noise scale, sampled outside the data plate
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = self.linear(x).squeeze(-1)
    # Alternative: with pyro.plate("data", size=self.DATA_LEN, subsample_size=self.BATCH_LEN):
    with pyro.plate("data", size=self.DATA_LEN, subsample=y):
        obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
    return mean