Hi!
I am very new to Pyro and I am trying to implement LDA in it. I have 32,000 documents with 1,500 words in each document, so my data has shape [1500, 32000].
I have made the following model:
def model(data=None, batch_size=None):
    """LDA generative model with per-topic word and per-document topic distributions.

    Args:
        data: Optional tensor of observed word ids. When given, it must have
            shape ``(num_words_per_doc, num_docs)``; it is indexed by the
            "documents" plate indices before being used as the observation.
        batch_size: Accepted for signature compatibility with the guide; it is
            not referenced inside this function.

    Returns:
        A tuple ``(topic_words, data)`` holding the sampled topic-word
        distributions and the value of the "doc_words" site.
    """
    # Global latent: one distribution over the vocabulary per topic.
    vocab_prior = dist.Dirichlet(torch.ones(num_words) / num_words)
    with pyro.plate("topics", num_topics):
        topic_words = pyro.sample("topic_words", vocab_prior)

    # Local latents: one topic mixture per document, one topic per word.
    mixture_prior = dist.Dirichlet(torch.ones(num_topics) / num_topics)
    with pyro.plate("documents", num_docs) as doc_idx:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (num_words_per_doc, num_docs)
            # Keep only the columns for the documents selected by the plate.
            data = data[:, doc_idx]

        doc_topics = pyro.sample("doc_topics", mixture_prior)

        with pyro.plate("words", num_words_per_doc):
            # The discrete topic assignment is flagged for parallel
            # enumeration, so the guide does not need a site for it.
            word_topics = pyro.sample(
                "word_topics",
                dist.Categorical(doc_topics),
                infer={"enumerate": "parallel"},
            )
            word_likelihood = dist.Categorical(topic_words[word_topics])
            data = pyro.sample("doc_words", word_likelihood, obs=data)

    return topic_words, data
and the following guide:
# Empty the global parameter store so previously registered parameters
# do not carry over into this run.
pyro.clear_param_store()


def my_local_guide(data=None, batch_size=None):
    """Variational guide for the LDA model's "topic_words" and "doc_topics" sites.

    Args:
        data: Unused here; present so the guide has the same signature as
            the model.
        batch_size: Subsample size passed to the "documents" plate. ``None``
            means no subsample size is specified.

    The guide registers two parameters:
        * ``topic_words_posterior`` -- positive Dirichlet concentrations of
          shape ``(num_topics, num_words)``.
        * ``doc_topics_posterior`` -- a point estimate on the simplex of shape
          ``(num_docs, num_topics)``, used through a ``Delta`` distribution.
    """
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(num_topics, num_words),
        constraint=constraints.positive,
    )
    with pyro.plate("topics", num_topics):
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    doc_topics_posterior = pyro.param(
        "doc_topics_posterior",
        # The initial value must satisfy the simplex constraint: each row has
        # to sum to 1. A tensor of plain ones sums to num_topics per row, so
        # start from the uniform topic mixture instead.
        lambda: torch.ones(num_docs, num_topics) / num_topics,
        constraint=constraints.simplex,
    )
    with pyro.plate("documents", num_docs, batch_size) as ind:
        # Select only the rows belonging to the documents in this mini-batch.
        pyro.sample("doc_topics", dist.Delta(doc_topics_posterior[ind], event_dim=1))
# Inference uses the hand-written guide.
guide = my_local_guide

elbo = TraceEnum_ELBO(max_plate_nesting=3)
# The optimizer's learning-rate option is named "lr".
optim = ClippedAdam({"lr": 0.05})
svi = SVI(model, guide, optim, elbo)

# Define the number of optimization steps
n_steps = 750

# do gradient steps
for step in range(n_steps):
    # svi.step returns the loss estimate for this step; it is stored under its
    # own name so the `elbo` objective object above is not overwritten.
    loss = svi.step(W_torch, batch_size=16)
    if step % 25 == 0:
        print("[%d] ELBO: %.1f" % (step, loss))
Here, W_torch is my dataset. When I run this, I get the following error:
DefaultCPUAllocator: not enough memory: you tried to allocate 2451600000 bytes. Buy new RAM!
How can I make Pyro run on large datasets without allocating so much RAM? I am not sure I understand why I get this error.