How to decrease GPU memory usage when large-scale nested plates appear in the model?

I built a time-series topic model with 3 nested plates. The model works well on a small-scale dataset, but when I tried to extend it to a large-scale dataset, I ran into a GPU out-of-memory error. I think this error is closely related to the nested plates:

import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex

# sample document word
with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1):  # I users
    with pyro.plate("document_sequence_length", self.obs_params["document_sequence_length"], dim=-2):  # D documents
        with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3):  # N words per document
            word_topic = pyro.sample("topic_of_each_word",
                                     dist.Categorical(theta),
                                     infer={"enumerate": "parallel"})  # shape=(N, D, I)
            p_word = Vindex(varphi)[word_topic]
            word = pyro.sample("words", dist.Categorical(p_word), obs=obs_words)
            if self._uncondition_flag:
                self.obs_data["text"] = word

Can I use for ... in plate instead of with plate to reduce the GPU memory allocation in my model? How do I implement it correctly?

You probably want to do something like

with pyro.plate('my_plate', size=10 ** 5, subsample_size=10 ** 3) as index:

Please refer to this tutorial.
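In case it helps to see it end to end, here is a minimal sketch of how subsampling could plug into the nested-plate model above. The sizes I, D, N, K, V and the flat theta/varphi tensors are placeholders I made up for illustration; only the large user plate is subsampled, and the indices it yields are used to slice the observed words:

import torch
import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex

# Placeholder sizes and data standing in for self.obs_params / obs_words.
I, D, N, K, V = 10 ** 5, 20, 50, 8, 1000  # users, docs, words, topics, vocab
theta = torch.full((K,), 1.0 / K)         # placeholder topic weights
varphi = torch.full((K, V), 1.0 / V)      # placeholder per-topic word dists
obs_words = torch.zeros(N, D, I, dtype=torch.long)  # placeholder observations

def model():
    # Subsample only the largest (user) plate; `ind` is a tensor of
    # 10 ** 3 random user indices, and pyro.plate rescales the
    # log-likelihood by I / subsample_size automatically, so gradient
    # estimates stay unbiased.
    with pyro.plate("user_plate_2", I, subsample_size=10 ** 3, dim=-1) as ind:
        with pyro.plate("document_sequence_length", D, dim=-2):
            with pyro.plate("docs_length", N, dim=-3):
                word_topic = pyro.sample("topic_of_each_word",
                                         dist.Categorical(theta),
                                         infer={"enumerate": "parallel"})
                p_word = Vindex(varphi)[word_topic]
                # Slice the observations down to the subsampled users.
                pyro.sample("words", dist.Categorical(p_word),
                            obs=obs_words[..., ind])

This is where the memory win comes from: parallel enumeration over word_topic adds a dimension of size K to the tensors at this site, so shrinking the user dimension from 10 ** 5 to 10 ** 3 shrinks the enumerated tensors by the same factor. A sequential for i in pyro.plate(...) loop would also cap memory, but looping over 10 ** 5 users in Python is usually far too slow to be practical.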

Thanks for your reply! I’ll try it.