Memory errors in LDA

Hi!

I am very new to Pyro and I am trying to implement LDA. I have 32,000 documents with 1,500 words each, so my data tensor has shape [1500, 32000] (word slots × documents).
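For reference, I set my constants and a dummy stand-in for the dataset roughly like this (the vocabulary size and number of topics below are placeholders, not my real values):

import torch
import pyro
import pyro.distributions as dist

num_docs = 32000          # number of documents
num_words_per_doc = 1500  # word tokens per document
num_words = 10000         # vocabulary size (placeholder)
num_topics = 8            # number of topics (placeholder)

# Dummy data: one word id per (position, document), same shape as my real W_torch.
W_torch = torch.randint(num_words, (num_words_per_doc, num_docs))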
I have made the following model:

def model(data=None, batch_size=None):
    # Global: a distribution over the vocabulary for each topic.
    with pyro.plate("topics", num_topics):
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(num_words) / num_words))

    # Local: a topic mixture per document and a topic per word slot.
    with pyro.plate("documents", num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (num_words_per_doc, num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample(
            "doc_topics", dist.Dirichlet(torch.ones(num_topics) / num_topics))
        with pyro.plate("words", num_words_per_doc):
            # The topic assignment of each word slot is enumerated out in parallel.
            word_topics = pyro.sample(
                "word_topics", dist.Categorical(doc_topics),
                infer={"enumerate": "parallel"})
            data = pyro.sample(
                "doc_words", dist.Categorical(topic_words[word_topics]), obs=data)

    return topic_words, data
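To see where the dimensions go, I have been printing a trace, following the Pyro tensor shapes tutorial (a sketch; I deliberately skip trace.compute_log_prob(), since computing the enumerated log-probs is exactly the big allocation):

from pyro import poutine

# first_available_dim = -(max_plate_nesting + 1) = -4 for max_plate_nesting=3.
trace = poutine.trace(
    poutine.enum(model, first_available_dim=-4)).get_trace(W_torch)
print(trace.format_shapes())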

Here is my guide:

from torch.distributions import constraints

pyro.clear_param_store()

def my_local_guide(data=None, batch_size=None):
    # Learnable Dirichlet concentrations for each topic's word distribution.
    topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(num_topics, num_words),
            constraint=constraints.positive)
    with pyro.plate("topics", num_topics):
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Point estimates (Delta) of each document's topic mixture; initialized
    # uniform so the initial value actually lies on the simplex.
    doc_topics_posterior = pyro.param(
            "doc_topics_posterior",
            lambda: torch.ones(num_docs, num_topics) / num_topics,
            constraint=constraints.simplex)
    with pyro.plate("documents", num_docs, batch_size) as ind:
        pyro.sample("doc_topics", dist.Delta(doc_topics_posterior[ind], event_dim=1))
    
from pyro.infer.autoguide import AutoGuideList, AutoDiagonalNormal

# I first tried combining an autoguide with the custom guide:
guide = AutoGuideList(model)
guide.add(AutoDiagonalNormal(pyro.poutine.block(model, expose=['doc_topics'])))
guide.add(my_local_guide)

# ...but this line overrides that, so currently I just use the custom guide:
guide = my_local_guide
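As a quick smoke test (my own sketch), I run the guide once on a small batch and check the parameter shapes:

# Run one guide pass so the params get registered, then inspect shapes.
my_local_guide(W_torch, batch_size=16)
print(pyro.param("topic_words_posterior").shape)  # expect (num_topics, num_words)
print(pyro.param("doc_topics_posterior").shape)   # expect (num_docs, num_topics)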

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import ClippedAdam

elbo = TraceEnum_ELBO(max_plate_nesting=3)

optim = ClippedAdam({'lr': 0.05})
svi = SVI(model, guide, optim, elbo)

# Define the number of optimization steps
n_steps = 750

# do gradient steps
for step in range(n_steps):
    elbo = svi.step(W_torch, batch_size=16)
    if step % 25 == 0:
        print("[%d] ELBO: %.1f" % (step, elbo))

W_torch is my dataset. But when I run it, I get the following error:

DefaultCPUAllocator: not enough memory: you tried to allocate 2451600000 bytes. Buy new RAM!
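My rough guess (I may be wrong, hence the question) is that enumerating word_topics in parallel makes the ELBO materialize a log-probability tensor over all topics, word slots, and all 32,000 documents at once, since my model's "documents" plate never subsamples. A back-of-the-envelope check with my placeholder topic count:

# Size of a float32 tensor of shape [num_topics, num_words_per_doc, num_docs].
approx_bytes = 8 * 1500 * 32000 * 4  # num_topics=8 is a placeholder
print(approx_bytes / 1e9, "GB")  # 1.536 GB, the same order as the failed allocation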

How can I make Pyro run on large data sets without allocating this much RAM? I am not sure I fully understand why I get this error.