I built a time-series topic model with three nested plates. The model works well on a small-scale dataset, but when I extended it to a large-scale dataset I ran into a GPU out-of-memory error. I suspect the error is closely related to the nested plates:
```python
import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex

# sample document word
with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1):  # I
    with pyro.plate("document_sequence_length",
                    self.obs_params["document_sequence_length"], dim=-2):  # D
        with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3):  # N
            word_topic = pyro.sample("topic_of_each_word",
                                     dist.Categorical(theta),
                                     infer={"enumerate": "parallel"})  # shape = (N, D, I)
            p_word = Vindex(varphi)[word_topic]
            word = pyro.sample("words", dist.Categorical(p_word), obs=obs_words)
            if self._uncondition_flag:
                self.obs_data["text"] = word
```
Can I use `for ... in pyro.plate(...)` instead of `with pyro.plate(...)` to reduce the GPU memory allocation in this model? How do I implement that correctly?
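For concreteness, here is a minimal, untested sketch of what I have in mind: iterate the largest plate (`docs_length`) sequentially while keeping the other two plates vectorized. It assumes `obs_words` has shape `(N, D, I)` so that `obs_words[n]` is the slice for position `n`, that `theta` broadcasts to the remaining `(D, I)` batch shape, and it reuses the plate names and `self.obs_params` keys from the model above (the `_uncondition_flag` branch is omitted for brevity):

```python
import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex

with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1):  # I
    with pyro.plate("document_sequence_length",
                    self.obs_params["document_sequence_length"], dim=-2):  # D
        # sequential plate: one iteration per n instead of a batched dim=-3
        for n in pyro.plate("docs_length", self.obs_params["document_length"]):  # N
            word_topic = pyro.sample(f"topic_of_each_word_{n}",  # unique site name per step
                                     dist.Categorical(theta),    # assumes theta broadcasts to (D, I)
                                     infer={"enumerate": "parallel"})
            p_word = Vindex(varphi)[word_topic]
            pyro.sample(f"words_{n}", dist.Categorical(p_word),
                        obs=obs_words[n])  # assumes obs_words.shape == (N, D, I)
```

Is this the right way to combine a sequential plate with vectorized plates and parallel enumeration, or does the per-iteration site naming (`topic_of_each_word_{n}`) cause problems for the guide or for enumeration?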