Subsampling ProdLDA Giving Error

Hi, I’m trying to subsample the ELBO in a ProdLDA model, but I keep getting the following error: ‘subsample_size does not match len(subsample), 1000 vs 32. Did you accidentally use different subsample_size in the model and guide?’

I read that when a plate appears in both the model and the guide, you only need to specify subsample_size in the guide, because the backend automatically reuses the same indices in the model. I tried specifying it in both, then in just the model, and then in just the guide, but I keep getting the same error.

The code for the model is below; it’s just the tutorial model. I can’t figure out how to subsample correctly with it, and I need to because the dataset I’m using is too large for full-batch ELBO gradients.

class ProdLDA(nn.Module):
    def __init__(self, vocab_size, num_topics, hidden, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_topics = num_topics
        self.encoder = Encoder(vocab_size, num_topics, hidden, dropout)
        self.decoder = Decoder(vocab_size, num_topics, dropout)
    
    def model(self, docs):
        pyro.module('decoder', self.decoder)
        with pyro.plate('documents', docs.shape[0]):
            logtheta_loc = docs.new_zeros((docs.shape[0], self.num_topics))
            logtheta_scale = docs.new_ones((docs.shape[0], self.num_topics))
            logtheta = pyro.sample('logtheta', dist.Normal(logtheta_loc, logtheta_scale).to_event(1))
            theta = F.softmax(logtheta, -1)
            count_param = self.decoder(theta)
            total_count = int(docs.sum(-1).max())
            pyro.sample('obs', dist.Multinomial(total_count, count_param), obs=docs)
    
    def guide(self, docs):
        pyro.module('encoder', self.encoder)
        with pyro.plate('documents', docs.shape[0], subsample_size=1000) as ind:
            logtheta_loc, logtheta_scale = self.encoder(docs.index_select(0, ind))
            logtheta = pyro.sample('logtheta', dist.Normal(logtheta_loc, logtheta_scale).to_event(1))

I think you’ll need to use index_select in the model as well, something like:

  def model(self, docs):
      pyro.module('decoder', self.decoder)
-     with pyro.plate('documents', docs.shape[0]):
+     with pyro.plate('documents', docs.shape[0]) as ind:
+         docs = docs.index_select(0, ind)
          logtheta_loc = docs.new_zeros((docs.shape[0], self.num_topics))
          logtheta_scale = docs.new_ones((docs.shape[0], self.num_topics))
          logtheta = pyro.sample('logtheta', dist.Normal(logtheta_loc, logtheta_scale).to_event(1))
          theta = F.softmax(logtheta, -1)
          count_param = self.decoder(theta)
          total_count = int(docs.sum(-1).max())
          pyro.sample('obs', dist.Multinomial(total_count, count_param), obs=docs)

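Actually, I think I figured out the error. Since I’m training with minibatches from a data_loader that feeds the model, the ELBO is already computed over only the minibatch, if I understand correctly. That would also explain the 32 in the error message: docs.shape[0] is the minibatch size, which is smaller than the subsample_size of 1000 in the guide. So it seems I don’t need to subsample the ELBO so much as just reduce the minibatch size.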
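
For anyone landing here with the same message, here is a rough plain-Python sketch of where "1000 vs 32" can come from. The helper plate_indices is hypothetical and not part of Pyro; it only imitates the length check that the error message describes. The batch size of 32 is an assumption read off that message.

```python
import random


def plate_indices(size, subsample_size=None, seed=0):
    """Toy stand-in for picking subsample indices inside a plate.

    Hypothetical helper, NOT Pyro's API: it only mirrors the check
    suggested by the error message quoted in this thread.
    """
    if subsample_size is None or subsample_size >= size:
        # Nothing to subsample: every index in the plate is used.
        ind = list(range(size))
    else:
        # Draw subsample_size distinct indices out of size.
        ind = random.Random(seed).sample(range(size), subsample_size)
    if subsample_size is not None and subsample_size != len(ind):
        raise ValueError(
            "subsample_size does not match len(subsample), "
            f"{subsample_size} vs {len(ind)}."
        )
    return ind


batch_size = 32  # assumed DataLoader batch size, inferred from the error

# The plate is sized by docs.shape[0], i.e. the minibatch, so a request
# for 1000 indices out of 32 cannot be satisfied.
try:
    plate_indices(size=batch_size, subsample_size=1000)
    mismatch = None
except ValueError as err:
    mismatch = str(err)

# Leaving subsample_size out uses the whole minibatch.
ind = plate_indices(size=batch_size)
```

In this sketch, once the plate is sized by the minibatch, leaving subsample_size out (as the model above does) simply uses every document in the batch, so the batch size set on the data_loader is what controls how many documents each ELBO step sees.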