Subsampling in two related plates

I’m trying to implement a dynamic-style topic model. Here is my model code.

def _model(self):
    """
    (Core) Implements the `model` structure of a Pyro model.

    Sampled sites, in order:
      * global ``delta_k``, ``delta_kappa``, ``delta_rho``;
      * inside ``topic_plate`` (dim=-1): ``varphi``, ``rho``, ``alpha``,
        ``beta``, ``tau``;
      * inside ``user_plate`` (dim=-2, optionally subsampled via
        ``self._subsample_size["I"]``): ``kappa`` and ``gamma_0`` ...
        ``gamma_{D-1}``, where each ``gamma_d`` (d > 0) depends on
        ``gamma_{d-1}`` through a time-decay term;
      * inside ``user_plate_2`` (dim=-1) x ``document_sequence_length``
        (dim=-2) x ``docs_length`` (dim=-3): the enumerated
        ``topic_of_each_word`` and the ``words``.

    ``user_plate_2`` reuses the user indices drawn by ``user_plate`` via its
    ``subsample`` argument. The rows of ``theta`` and of the observed text
    scored there come from those same indices. Giving the second plate its
    own ``subsample_size`` would make it draw an unrelated second random
    subsample.

    If ``self._uncondition_flag`` is truthy, the words are sampled instead of
    observed, and the sampled tensor is written to ``self.obs_data["text"]``.
    """
    # Number of users processed per call; the whole population when
    # subsampling is disabled.
    if self._subsample_size is None:
        user_subsample_size = self.obs_params["I"]
    else:
        user_subsample_size = self._subsample_size["I"]

    prior = self._model_prior
    n_docs = self.obs_params["document_sequence_length"]

    delta_k = pyro.sample("delta_k", dist.Normal(prior["delta_k_loc"], prior["delta_k_scale"]))
    delta_kappa = pyro.sample("delta_kappa", dist.Normal(prior["delta_kappa_loc"], prior["delta_kappa_scale"]))
    delta_rho = pyro.sample("delta_rho", dist.Normal(prior["delta_rho_loc"], prior["delta_rho_scale"]))

    gamma = []  # gamma_d for d = 0 .. D-1, each shape=(I, K)
    with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
        varphi = pyro.sample("varphi", dist.Dirichlet(prior["eta"]))
        rho = pyro.sample("rho", dist.MultivariateNormal(prior["rho_loc"], prior["rho_scale"]))
        alpha = pyro.sample("alpha", dist.MultivariateNormal(prior["alpha_loc"], prior["alpha_scale"]))
        beta = pyro.sample("beta", dist.MultivariateNormal(prior["beta_loc"], prior["beta_scale"]))
        tau = pyro.sample("tau", dist.Gamma(prior["tau_a"], prior["tau_b"]))
        with pyro.plate("user_plate", self.obs_params["I"], subsample_size=user_subsample_size,
                        dim=-2) as u_ind_model:
            u_ind_model = u_ind_model.to(self.device)
            kappa = pyro.sample("kappa", dist.Normal(prior["kappa_loc"], prior["kappa_scale"]))
            for d in range(n_docs):
                # Covariates of the subsampled users for document d. Only a
                # single index is needed here; `[d, ..., ...]` raises because
                # tensor indexing allows at most one Ellipsis.
                x = self.obs_data["x"][d].index_select(0, u_ind_model)  # shape=(I, X)
                if d == 0:
                    gamma_loc = delta_k + delta_kappa * kappa + delta_rho * (x @ rho.T)  # shape=(I, K)
                else:
                    timedelta = self.obs_data["timedelta"][d - 1, ...].index_select(0, u_ind_model)
                    time_decay_coef = alpha * (-beta.unsqueeze(-3) * timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
                    # Propagate the previous document's gamma through the decay matrix.
                    time_decay_term = torch.bmm(time_decay_coef, gamma[d - 1].unsqueeze(-1)).squeeze(-1)
                    gamma_loc = kappa + time_decay_term + x @ rho.T
                gamma.append(pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau)))

    theta = F.softmax(torch.stack(gamma), dim=-1)  # shape=(D, I, K), document topic distribution

    # Observed words for exactly the users subsampled above (None => sample the
    # words instead). A truthiness test avoids leaving `obs_words` unbound when
    # the flag is not a strict bool.
    if self._uncondition_flag:
        obs_words = None
    else:
        obs_words = self.obs_data["text"].index_select(2, u_ind_model)  # shape=(N, D, I)

    # sample document word
    # Pass the indices from `user_plate` explicitly so both user plates share a
    # single subsample (and the same I/subsample_size scaling).
    with pyro.plate("user_plate_2", self.obs_params["I"], subsample=u_ind_model, dim=-1):  # I
        with pyro.plate("document_sequence_length", n_docs, dim=-2):  # D
            with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3):  # N
                word_topic = pyro.sample("topic_of_each_word",
                                         dist.Categorical(theta),
                                         infer={"enumerate": "parallel"})  # shape=(N, D, I)
                p_word = Vindex(varphi)[word_topic]
                word = pyro.sample("words", dist.Categorical(p_word), obs=obs_words)
                if obs_words is None:
                    self.obs_data["text"] = word

To draw the topic proportions, I subsample users in `user_plate`, which comes before the `# sample document word` section. When drawing the words, do I need to subsample again in `user_plate_2`, or should I just set that plate's size to the subsample size? The corresponding code is:

with pyro.plate("user_plate", self.obs_params["I"], subsample_size=user_subsample_size,
                dim=-2) as u_ind_model:
# sample document word
with pyro.plate("user_plate_2", self.obs_params["I"], subsample_size=user_subsample_size, dim=-1):  # I