How do I write a guide for a complex dynamic topic model in Pyro?

I am using Pyro to implement a complex dynamic topic model. AutoGuides do not seem to work for my model, so I chose to write the guide manually. The variational distribution I wrote follows the mean-field assumption: all variables with subscripts k, i, or d are local, conditionally independent random variables.


However, implementing this mean-field variational distribution by hand is complicated. Every local, conditionally independent random variable needs its own independent variational parameter, which leads to a very large number of variational parameters in the guide.

def _manual_guide(self):
    """Mean-field variational guide (the Pyro ``guide``) for the dynamic topic model.

    Registers the variational parameters in Pyro's param store and draws each
    latent site from its variational factor:

    * global: ``delta_k``, ``delta_kappa``, ``delta_rho`` (Normal),
      ``mu_kappa`` (MultivariateNormal), ``Lambda_kappa`` (Wishart);
    * per user: ``kappa`` (MultivariateNormal);
    * per topic: ``varphi`` (Dirichlet), ``rho`` / ``alpha`` / ``beta``
      (MultivariateNormal), ``tau`` (Gamma);
    * per document ``d`` (users x topics): ``gamma_{d}`` (Normal);
    * per word / document / subsampled user: ``word_topic`` (Categorical over
      ``softmax(gamma)``).

    Parameter names are unchanged (one parameter per index, e.g.
    ``kappa_mu_{i}``, ``gamma_mu_{i}_{d}``, ``rho_loc_{k}``), so previously
    saved param stores stay compatible.

    NOTE(review): the usual way to avoid one ``pyro.param`` per index is a
    single batched parameter per family (e.g. one ``kappa_mu`` tensor with a
    leading user dimension). The enclosing plates already declare the
    conditional independence. That design renames the params, so it is not
    applied here.

    Returns:
        None. The guide is used only for its ``pyro.param`` / ``pyro.sample``
        side effects.
    """
    num_users = self.obs_params["I"]
    num_topics = self.obs_params["K"]
    num_docs = self.obs_params["document_sequence_length"]

    # Subsample size for the word-level user plate; None means full batch.
    if self._subsample_size is None:
        user_subsample_size = num_users
    else:
        user_subsample_size = self._subsample_size["I"]

    def _init(key):
        # Lazy initializer that yields a *fresh copy* of the prior value.
        # pyro.param stores the unconstrained tensor it is given, and for the
        # default constraints.real that transform is the identity. Passing the
        # same prior tensor under several names (kappa_mu_0, kappa_mu_1, ...)
        # therefore made them a single shared tensor: the per-index locs were
        # tied together, and the optimizer also overwrote self._guide_prior.
        return lambda: self._guide_prior[key].detach().clone()

    # ---- global parameters ----
    delta_k_loc = pyro.param("delta_k_loc", _init("delta_k_loc"))
    delta_k_scale = pyro.param("delta_k_scale", _init("delta_k_scale"),
                               constraint=constraints.greater_than(0.))
    delta_kappa_loc = pyro.param("delta_kappa_loc", _init("delta_kappa_loc"))
    delta_kappa_scale = pyro.param("delta_kappa_scale", _init("delta_kappa_scale"),
                                   constraint=constraints.greater_than(0.))
    delta_rho_loc = pyro.param("delta_rho_loc", _init("delta_rho_loc"))
    delta_rho_scale = pyro.param("delta_rho_scale", _init("delta_rho_scale"),
                                 constraint=constraints.greater_than(0.))
    mu_mu_kappa = pyro.param("mu_mu_kappa", _init("mu_mu_kappa"))
    Sigma_mu_kappa = pyro.param("Sigma_mu_kappa", _init("Sigma_mu_kappa"),
                                constraint=constraints.positive_definite)
    V_Lambda_kappa = pyro.param("V_Lambda_kappa", _init("V_Lambda_kappa"),
                                constraint=constraints.positive_definite)
    # torch's Wishart rejects df <= p - 1 (p = matrix dimension). A plain
    # positivity constraint lets the optimizer step into that invalid region.
    wishart_df_min = float(V_Lambda_kappa.shape[-1] - 1)
    n_Lambda_kappa = pyro.param("n_Lambda_kappa", _init("n_Lambda_kappa"),
                                constraint=constraints.greater_than(wishart_df_min))

    # ---- per-user parameters: one pyro.param per user i ----
    kappa_mu, kappa_Sigma = [], []
    for i in range(num_users):
        kappa_mu.append(pyro.param(f"kappa_mu_{i}", _init("kappa_mu")))
        kappa_Sigma.append(pyro.param(f"kappa_Sigma_{i}", _init("kappa_Sigma"),
                                      constraint=constraints.positive_definite))

    # ---- per-(document, user) parameters: gamma_mu[d][i], gamma_sigma[d][i] ----
    gamma_mu, gamma_sigma = [], []
    for d in range(num_docs):
        gamma_mu.append([pyro.param(f"gamma_mu_{i}_{d}", _init("gamma_mu"))
                         for i in range(num_users)])
        gamma_sigma.append([pyro.param(f"gamma_Sigma_{i}_{d}", _init("gamma_Sigma"),
                                       constraint=constraints.positive)
                            for i in range(num_users)])

    # ---- per-topic parameters: one pyro.param per topic k ----
    eta, rho_loc, rho_scale = [], [], []
    alpha_loc, alpha_scale, beta_loc, beta_scale = [], [], [], []
    tau_a, tau_b = [], []
    for k in range(num_topics):
        eta.append(pyro.param(f"eta_{k}", _init("eta_v"), constraint=constraints.positive))
        rho_loc.append(pyro.param(f"rho_loc_{k}", _init("rho_loc")))
        # These *_scale tensors are MultivariateNormal covariance matrices,
        # which must be positive definite. An elementwise-positive constraint
        # does not guarantee that.
        rho_scale.append(pyro.param(f"rho_scale_{k}", _init("rho_scale"),
                                    constraint=constraints.positive_definite))
        alpha_loc.append(pyro.param(f"alpha_loc_{k}", _init("alpha_loc")))
        alpha_scale.append(pyro.param(f"alpha_scale_{k}", _init("alpha_scale"),
                                      constraint=constraints.positive_definite))
        beta_loc.append(pyro.param(f"beta_loc_{k}", _init("beta_loc"),
                                   constraint=constraints.greater_than(1.)))
        # Registered as "scale_beta_{k}" rather than "beta_scale_{k}"; the name
        # is kept for param-store compatibility.
        beta_scale.append(pyro.param(f"scale_beta_{k}", _init("beta_scale"),
                                     constraint=constraints.positive_definite))
        tau_a.append(pyro.param(f"tau_a_{k}", _init("tau_a"),
                                constraint=constraints.greater_than(0.)))
        tau_b.append(pyro.param(f"tau_b_{k}", _init("tau_b"),
                                constraint=constraints.greater_than(0.)))

    # ---- global latent variables ----
    pyro.sample("delta_k", dist.Normal(delta_k_loc, delta_k_scale))
    pyro.sample("delta_kappa", dist.Normal(delta_kappa_loc, delta_kappa_scale))
    pyro.sample("delta_rho", dist.Normal(delta_rho_loc, delta_rho_scale))
    pyro.sample("mu_kappa", dist.MultivariateNormal(mu_mu_kappa, Sigma_mu_kappa))
    pyro.sample("Lambda_kappa", dist.Wishart(n_Lambda_kappa, V_Lambda_kappa))

    # ---- per-user kappa: stacked over users, batched along plate dim -1 ----
    with pyro.plate("kappa_plate", num_users, dim=-1):
        pyro.sample("kappa", dist.MultivariateNormal(torch.stack(kappa_mu),
                                                      torch.stack(kappa_Sigma)))

    # ---- per-topic variables (topics on dim -1), per-user gamma (users on dim -2) ----
    q_gamma = []
    with pyro.plate("topic_plate", num_topics, dim=-1):
        pyro.sample("varphi", dist.Dirichlet(torch.stack(eta)))
        pyro.sample("rho", dist.MultivariateNormal(torch.stack(rho_loc), torch.stack(rho_scale)))
        pyro.sample("alpha", dist.MultivariateNormal(torch.stack(alpha_loc), torch.stack(alpha_scale)))
        pyro.sample("beta", dist.MultivariateNormal(torch.stack(beta_loc), torch.stack(beta_scale)))
        pyro.sample("tau", dist.Gamma(torch.stack(tau_a), torch.stack(tau_b)))
        with pyro.plate("user_plate", num_users, dim=-2):
            # NOTE(review): gamma is drawn for *all* users on every step, while
            # word_topic below covers only a subsample of users. Confirm the
            # model mirrors this plate/subsampling structure exactly.
            for d in range(num_docs):
                q_gamma.append(pyro.sample(
                    f"gamma_{d}",
                    dist.Normal(torch.stack(gamma_mu[d]), torch.stack(gamma_sigma[d]))))

    # Topic proportions: documents stacked on the leading dim, softmax over the last dim.
    q_theta = F.softmax(torch.stack(q_gamma), dim=-1)

    # ---- word-level topic assignments for a subsample of users ----
    with pyro.plate("user_plate_2", num_users, subsample_size=user_subsample_size,
                    dim=-1) as user_idx:
        user_idx = user_idx.to(self.device)
        with pyro.plate("document_sequence", num_docs, dim=-2):
            with pyro.plate("document", self.obs_params["document_length"], dim=-3):
                # Pyro's enumeration strategies are "sequential" and "parallel".
                # The previous value "parallels" was a typo, so the site was not
                # enumerated in parallel. Parallel enumeration requires
                # TraceEnum_ELBO with max_plate_nesting >= 3 (plates use dims -1..-3).
                # NOTE(review): enumerating discrete sites in the *model* and
                # omitting them from the guide is the more common Pyro pattern.
                pyro.sample("word_topic",
                            dist.Categorical(q_theta.index_select(1, user_idx)),
                            infer={"enumerate": "parallel"})

My model currently runs successfully, but the results are not as good as expected, and I am unsure whether my understanding of how to write the guide manually is correct.

The ELBO loss curve looks good, but the parameters learned during training seem less accurate than expected.