Negative TraceEnum_ELBO loss in time-series topic model

I’m trying to build a time-series topic model, essentially a time-series variant of the correlated topic model (CTM). Here are my model and guide:

import torch
import torch.nn.functional as F

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.distributions import constraints
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex


def _model(self, include_prior=True):
    """
    Implements the data-generating process as a Pyro model.
    """
    # loc_mu_kappa = torch.zeros(self.obs_params["K"], device=self.device) * 0.5
    # scale_mu_kappa = torch.eye(self.obs_params["K"], device=self.device) * 0.05
    # mu_kappa = pyro.sample(f"mu_kappa", dist.MultivariateNormal(loc_mu_kappa, scale_mu_kappa))

    # n_Lambda_kappa = torch.tensor(3., device=self.device) * self.obs_params["K"]
    # V_Lambda_kappa = torch.eye(self.obs_params["K"], device=self.device) * (1. / n_Lambda_kappa)
    # Lambda_kappa = pyro.sample(f"Lambda_kappa", dist.Wishart(df=n_Lambda_kappa, covariance_matrix=V_Lambda_kappa))

    with poutine.mask(mask=include_prior):
        delta_k = pyro.sample(
            "delta_k", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))
        delta_kappa = pyro.sample(
            "delta_kappa", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))
        delta_rho = pyro.sample(
            "delta_rho", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))

        eta = torch.tensor([1 / self.obs_params["V"]] * self.obs_params["V"], device=self.device)
        loc_rho = torch.tensor([.5] * self.obs_params["X"], device=self.device)
        scale_rho = torch.eye(self.obs_params["X"], device=self.device) * 0.05
        loc_alpha = torch.tensor([1.] * self.obs_params["K"], device=self.device)
        scale_alpha = torch.eye(self.obs_params["K"], device=self.device) * 0.05
        loc_beta = torch.tensor([5.] * self.obs_params["K"], device=self.device)
        scale_beta = torch.eye(self.obs_params["K"], device=self.device) * 0.5
        a_tau = torch.tensor(2., device=self.device)
        b_tau = torch.tensor(1., device=self.device)

        # initialize gamma
        gamma = []

        # sample topic-level parameters and per-document topic weights
        with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
            varphi = pyro.sample("varphi", dist.Dirichlet(eta))
            rho = pyro.sample("rho", dist.MultivariateNormal(loc_rho, scale_rho))
            alpha = pyro.sample("alpha", dist.MultivariateNormal(loc_alpha, scale_alpha))
            beta = pyro.sample("beta", dist.MultivariateNormal(loc_beta, scale_beta))
            tau = pyro.sample("tau", dist.Gamma(a_tau, b_tau))
            with pyro.plate("user_plate", self.obs_params["I"], dim=-2):
                kappa = pyro.sample(
                    "kappa", dist.Normal(torch.tensor(1., device=self.device), torch.tensor(.5, device=self.device)))
                for d in range(self.obs_params["document_sequence_length"]):
                    x = self.obs_data["x"][d]  # shape=(I, X)
                    if d == 0:
                        gamma_loc = delta_k + delta_kappa * kappa + delta_rho * (x @ rho.T)  # shape=(I, K)
                        gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau))
                        gamma.append(gamma_d)

                    else:
                        timedelta = self.obs_data["timedelta"][d - 1, ...]
                        time_decay_coef = alpha * (-beta.unsqueeze(-3) * timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
                        last_gamma = gamma[d - 1].unsqueeze(-1)
                        time_decay_term = torch.bmm(time_decay_coef, last_gamma)
                        time_decay_term = time_decay_term.squeeze(-1)
                        gamma_loc = kappa + time_decay_term + x @ rho.T
                        # gamma_loc = kappa + torch.bmm(time_decay_coef, last_gamma).squeeze(-1) + x @ rho.T
                        gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau))
                        gamma.append(gamma_d)

        gamma = torch.stack(gamma)

        theta = F.softmax(gamma, dim=-1)  # shape=(D, I, K), document topic distribution

    obs_words = None if self._uncondition_flag else self.obs_data["text"]  # shape=(N, D, I)

    # sample document word
    with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1):  # I
        with pyro.plate("document_sequence_length", self.obs_params["document_sequence_length"], dim=-2):  # D
            with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3):  # N
                word_topic = pyro.sample("topic_of_each_word",
                                         dist.Categorical(theta),
                                         infer={"enumerate": "parallel"})  # shape=(N, D, I)
                p_word = Vindex(varphi)[word_topic]
                word = pyro.sample("words", dist.Categorical(p_word), obs=obs_words)
                if self._uncondition_flag:
                    self.obs_data["text"] = word

def _guide(self, use_autoguide=True):
    if use_autoguide:
        return AutoDelta(self._model)
    else:
        # learnable parameters
        delta_k_loc = pyro.param("delta_k_loc", lambda: torch.tensor(0.5, device=self.device))
        delta_k_scale = pyro.param("delta_k_scale",
                                   lambda: torch.tensor(0.05, device=self.device),
                                   constraint=constraints.greater_than(0.0))
        delta_kappa_loc = pyro.param("delta_kappa_loc", lambda: torch.tensor(0.5, device=self.device))
        delta_kappa_scale = pyro.param("delta_kappa_scale",
                                       lambda: torch.tensor(0.05, device=self.device),
                                       constraint=constraints.greater_than(0.0))
        delta_rho_loc = pyro.param("delta_rho_loc", lambda: torch.tensor(0.5, device=self.device))
        delta_rho_scale = pyro.param("delta_rho_scale",
                                     lambda: torch.tensor(0.05, device=self.device),
                                     constraint=constraints.greater_than(0.0))
        eta_v = pyro.param(
            "eta_v",
            lambda: torch.tensor([1 / self.obs_params["V"]] * self.obs_params["V"], device=self.device),
            constraint=constraints.simplex)
        loc_rho = pyro.param("loc_rho", lambda: torch.tensor([.5] * self.obs_params["X"], device=self.device))
        # covariance parameters must stay positive definite (not merely
        # elementwise positive), or MultivariateNormal will reject them
        scale_rho = pyro.param("scale_rho",
                               lambda: torch.eye(self.obs_params["X"], device=self.device) * 0.05,
                               constraint=constraints.positive_definite)
        loc_alpha = pyro.param("loc_alpha",
                               lambda: torch.tensor([1.] * self.obs_params["K"], device=self.device))
        scale_alpha = pyro.param("scale_alpha",
                                 lambda: torch.eye(self.obs_params["K"], device=self.device) * 0.05,
                                 constraint=constraints.positive_definite)
        loc_beta = pyro.param("loc_beta",
                              lambda: torch.tensor([2.] * self.obs_params["K"], device=self.device),
                              constraint=constraints.greater_than(1.0))
        scale_beta = pyro.param("scale_beta",
                                lambda: torch.eye(self.obs_params["K"], device=self.device) * 0.5,
                                constraint=constraints.positive_definite)
        a_tau = pyro.param("a_tau",
                           lambda: torch.tensor(1., device=self.device),
                           constraint=constraints.greater_than(0.0))
        b_tau = pyro.param("b_tau",
                           lambda: torch.tensor(1., device=self.device),
                           constraint=constraints.greater_than(0.0))
        loc_kappa = pyro.param("loc_kappa", lambda: torch.tensor(0., device=self.device))
        scale_kappa = pyro.param("scale_kappa",
                                 lambda: torch.tensor(0.05, device=self.device),
                                 constraint=constraints.greater_than(0.0))

        # rvs
        delta_k_q = pyro.sample("delta_k", dist.Normal(delta_k_loc, delta_k_scale))
        delta_kappa_q = pyro.sample("delta_kappa", dist.Normal(delta_kappa_loc, delta_kappa_scale))
        delta_rho_q = pyro.sample("delta_rho", dist.Normal(delta_rho_loc, delta_rho_scale))

        gamma = []

        with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
            varphi_q = pyro.sample("varphi", dist.Dirichlet(eta_v))
            rho_q = pyro.sample("rho", dist.MultivariateNormal(loc_rho, scale_rho))
            alpha_q = pyro.sample("alpha", dist.MultivariateNormal(loc_alpha, scale_alpha))
            beta_q = pyro.sample("beta", dist.MultivariateNormal(loc_beta, scale_beta))
            tau_q = pyro.sample("tau", dist.Gamma(a_tau, b_tau))
            with pyro.plate("user_plate", self.obs_params["I"], dim=-2):
                kappa_q = pyro.sample("kappa", dist.Normal(loc_kappa, scale_kappa))
                for d in range(self.obs_params["document_sequence_length"]):
                    x = self.obs_data["x"][d]  # shape=(I, X)
                    if d == 0:
                        gamma_loc = delta_k_q + delta_kappa_q * kappa_q + delta_rho_q * (x @ rho_q.T)  # shape=(I, K)
                        gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau_q))
                        gamma.append(gamma_d)
                    else:
                        timedelta = self.obs_data["timedelta"][d - 1, ...]
                        time_decay_coef = alpha_q * (-beta_q.unsqueeze(-3) *
                                                     timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
                        last_gamma = gamma[d - 1].unsqueeze(-1)
                        time_decay_term = torch.bmm(time_decay_coef, last_gamma)
                        time_decay_term = time_decay_term.squeeze(-1)
                        gamma_loc = kappa_q + time_decay_term + x @ rho_q.T
                        # gamma_loc = kappa_q + torch.bmm(time_decay_coef, last_gamma).squeeze(-1) + x @ rho_q.T
                        gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau_q))
                        gamma.append(gamma_d)

        gamma = torch.stack(gamma)

        theta = F.softmax(gamma, dim=-1)

        with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1):  # I
            with pyro.plate("document_sequence_length", self.obs_params["document_sequence_length"], dim=-2):  # D
                with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3):  # N
                    pyro.sample("topic_of_each_word",
                                dist.Categorical(theta),
                                infer={"enumerate": "parallel"})  # shape=(N, D, I)
                    # the observed site "words" is omitted here: observed sites
                    # belong in the model only, and sampling them in the guide
                    # would add a spurious log q term to the ELBO

When I tried to train this model, it looked good initially, but after a few hundred iterations the loss turned negative and didn’t seem to be converging. I used the poutine.uncondition() handler to reuse the Pyro model to generate simulation data for validating my model. Now I’m confused: I don’t know whether there is a problem with the generated simulation data or an error in my code.
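
For reference, the simulation step looks roughly like this (a sketch rather than my exact code; it assumes the _model above and reads the sampled site back out of the trace):

from pyro import poutine

# lift the obs= statement so that "words" is sampled instead of conditioned
unconditioned_model = poutine.uncondition(self._model)
trace = poutine.trace(unconditioned_model).get_trace()
simulated_words = trace.nodes["words"]["value"]  # used as self.obs_data["text"]

With that simulated data, the training loss looks like this: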

[2022-09-19 19:33:29 INFO]: On iteration 617, loss = 108.1953125
[2022-09-19 19:33:30 INFO]: On iteration 618, loss = 148.8046875
[2022-09-19 19:33:30 INFO]: On iteration 619, loss = 136.1484375
[2022-09-19 19:33:30 INFO]: On iteration 620, loss = 76.56640625
[2022-09-19 19:33:30 INFO]: On iteration 621, loss = 117.4921875
[2022-09-19 19:33:30 INFO]: On iteration 622, loss = 89.83984375
[2022-09-19 19:33:30 INFO]: On iteration 623, loss = 102.1171875
[2022-09-19 19:33:30 INFO]: On iteration 624, loss = 44.41796875
[2022-09-19 19:33:31 INFO]: On iteration 625, loss = 51.1875
[2022-09-19 19:33:31 INFO]: On iteration 626, loss = 112.4296875
[2022-09-19 19:33:31 INFO]: On iteration 627, loss = 73.48828125
[2022-09-19 19:33:31 INFO]: On iteration 628, loss = 43.6484375
[2022-09-19 19:33:31 INFO]: On iteration 629, loss = 51.51171875
[2022-09-19 19:33:31 INFO]: On iteration 630, loss = 6.7734375
[2022-09-19 19:33:32 INFO]: On iteration 631, loss = 24.2265625
[2022-09-19 19:33:32 INFO]: On iteration 632, loss = 78.203125
[2022-09-19 19:33:32 INFO]: On iteration 633, loss = -18.22265625
[2022-09-19 19:33:32 INFO]: On iteration 634, loss = -36.3828125
[2022-09-19 19:33:32 INFO]: On iteration 635, loss = -10.390625
[2022-09-19 19:33:32 INFO]: On iteration 636, loss = -54.13671875
[2022-09-19 19:33:32 INFO]: On iteration 637, loss = -52.375
[2022-09-19 19:33:33 INFO]: On iteration 638, loss = -20.8984375
[2022-09-19 19:33:33 INFO]: On iteration 639, loss = -10.453125
[2022-09-19 19:33:33 INFO]: On iteration 640, loss = -97.80859375
[2022-09-19 19:33:33 INFO]: On iteration 641, loss = -41.1796875
[2022-09-19 19:33:33 INFO]: On iteration 642, loss = -17.46875
[2022-09-19 19:33:33 INFO]: On iteration 643, loss = -35.69921875
[2022-09-19 19:33:34 INFO]: On iteration 644, loss = -126.4375

I haven’t looked at your model, but in general:

  • It’s perfectly reasonable for the ELBO loss to be negative if you have any continuous variables: the ELBO loss term for a continuous variable is a negative log-density, and depending on units the log density can be arbitrarily shifted up or down. For example, switching some parameter from inches to meters will shift its log density by log(inches/meter). (See the toy example after this list.)
  • When using Pyro’s SVI, the ELBO loss is a stochastic estimate. If you want a more accurate estimate for evaluation, you can set num_particles to a large number, but that would slow down learning if used during training; see the sketch below.
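
As a toy illustration of the first point (the numbers are made up): any sufficiently concentrated continuous density exceeds 1 near its mode, so its log-density there is positive and its contribution to the loss is negative.

import torch
import pyro.distributions as dist

# a Normal with a small scale has density > 1 near its mode,
# hence a positive log-density and a negative loss contribution
d = dist.Normal(torch.tensor(0.), torch.tensor(0.01))
print(d.log_prob(torch.tensor(0.)))  # ~3.69 > 0

And a sketch of the second point, where model and guide stand for your _model/_guide pair (max_plate_nesting=3 matches your three nested plates): train with the cheap single-particle estimator and evaluate occasionally with many particles.

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

train_elbo = TraceEnum_ELBO(max_plate_nesting=3)                    # fast, noisy
eval_elbo = TraceEnum_ELBO(max_plate_nesting=3, num_particles=100)  # slow, accurate

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=train_elbo)
for step in range(1000):
    svi.step()
    if step % 100 == 0:
        print(eval_elbo.loss(model, guide))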

Thanks for your reply! The problem is solved. The cause of the negative loss was the parameter of the Dirichlet distribution in this model: when generating the simulation data I had set a sparse symmetric concentration (1/V, as in the code above) for the Dirichlet, which seems hard to learn via SVI and produced a diverging trend in the loss plot. After modifying this parameter the loss behaves correctly and my model now works well! (A toy illustration of why this drives the loss negative is below.)
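
For anyone hitting the same symptom, here is a toy illustration (V = 50 is made up, not my real vocabulary size): with a sparse symmetric concentration like 1/V, Dirichlet samples pile up at the corners of the simplex, where the density is enormous, so the corresponding log-density terms can push the ELBO loss far below zero.

import torch
import pyro.distributions as dist

V = 50
# float64 avoids components underflowing to exact zeros
sparse = dist.Dirichlet(torch.full((V,), 1.0 / V, dtype=torch.float64))
topics = sparse.sample()        # nearly one-hot
print(sparse.log_prob(topics))  # large and positive, on the order of 1e3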
