RuntimeError about autograd (in-place operation) when building a time series probabilistic model

I’m building a time series probabilistic model based on LDA in Pyro, and the model is almost finished; I’m using a custom guide. But when I run my code, I get: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation.
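
For reference, the error itself is easy to reproduce outside my model with a few toy tensors (the snippet below is a minimal sketch with made-up names, not my actual code):

import torch

x = torch.ones(3, requires_grad=True)
buf = torch.empty(2, 3)       # preallocated buffer, like gamma below
buf[0] = x * 2                # in-place copy; buf joins the autograd graph
prev = buf[0]                 # a view into buf
out = prev ** 2               # pow saves `prev` (a view of buf) for backward
buf[1] = out                  # in-place write bumps buf's version counter
out.sum().backward()          # RuntimeError: ... modified by an inplace operation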

Here is the relevant code (the main structure of the dynamic part). The line most relevant to the error is time_decay_term = torch.bmm(time_decay_coef, last_gamma_unsq). I’m confused and don’t know how to fix this bug:

# preallocated buffer for the per-step gamma samples; the in-place writes
# into it below turn out to be the cause of the error
gamma = torch.empty((self.obs_params["document_sequence_length"], self.obs_params["I"], self.obs_params["K"]),
                    device=self.device)

# sample document-topic distributions
with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
    varphi = pyro.sample("varphi", dist.Dirichlet(eta))
    rho = pyro.sample("rho", dist.MultivariateNormal(loc_rho, scale_rho))
    alpha = pyro.sample("alpha", dist.MultivariateNormal(loc_alpha, scale_alpha))
    beta = pyro.sample("beta", dist.MultivariateNormal(loc_beta, scale_beta))
    tau = pyro.sample("tau", dist.Gamma(a_tau, b_tau))
    with pyro.plate("user_plate", self.obs_params["I"], dim=-2):
        kappa = pyro.sample(
            "kappa", dist.Normal(torch.tensor(0., device=self.device), torch.tensor(1., device=self.device)))
        for d in range(self.obs_params["document_sequence_length"]):
            x = self.obs_data["x"].permute(1, 0, 2)[d]  # shape=(I, X)
            if d == 0:
                gamma_loc = delta_k + delta_kappa * kappa + delta_rho * (x @ rho.T)  # shape=(I, K)
                gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau))
                gamma[d] = gamma_d  # in-place write into the preallocated buffer

            else:
                timedelta = self.obs_data["timedelta"].permute(1, 0)[d - 1]
                time_decay_coef = alpha * (-beta.unsqueeze(-3) * timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
                last_gamma = gamma[d - 1]  # a view of gamma; bmm below saves it for backward
                last_gamma_unsq = last_gamma.unsqueeze(-1)
                time_decay_term = torch.bmm(time_decay_coef, last_gamma_unsq)
                time_decay_term = time_decay_term.squeeze(-1)
                gamma_loc = kappa + time_decay_term + x @ rho.T
                # gamma_loc = kappa + torch.bmm(time_decay_coef, last_gamma).squeeze(-1) + x @ rho.T
                gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau))
                gamma[d] = gamma_d  # this in-place write invalidates the slice saved above

The problem has been solved: gamma_d should be kept in a Python list (or tuple) rather than written into a preallocated tensor. The assignment gamma[d] = gamma_d is an in-place operation on gamma, and since last_gamma = gamma[d - 1] is a view of that same tensor saved by torch.bmm for the backward pass, the autograd mechanism fails once gamma is modified.
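
For anyone who runs into the same thing, here is a minimal sketch of the fixed pattern (toy tensors again, names made up), assuming all you need is the stacked tensor at the end:

import torch

x = torch.ones(3, requires_grad=True)

# accumulate per-step results in a Python list; nothing is mutated in place,
# so the tensors autograd saved for backward stay valid
steps = [x * 2]
for d in range(1, 4):
    steps.append(steps[d - 1] ** 2)   # read the previous step from the list

stacked = torch.stack(steps)          # build the full tensor once at the end
stacked.sum().backward()              # backward now succeeds
print(x.grad)

In the model above, this means appending gamma_d to a list inside the loop, reading last_gamma from that list instead of from gamma, and calling torch.stack once after the loop if the full tensor is needed.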