Negative TraceEnum_ELBO loss in time-series topic model

I’m trying to build a time-series topic model, something like a time-series extension of the correlated topic model (CTM). Here is my model and guide code:

# imports assumed by the snippets below
import torch
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.distributions import constraints
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex


def _model(self, include_prior=True):
    """
    Implements the data generation process by pyro model
    """
    # loc_mu_kappa = torch.zeros(self.obs_params["K"], device=self.device) * 0.5
    # scale_mu_kappa = torch.eye(self.obs_params["K"], device=self.device) * 0.05
    # mu_kappa = pyro.sample(f"mu_kappa", dist.MultivariateNormal(loc_mu_kappa, scale_mu_kappa))

    # n_Lambda_kappa = torch.tensor(3., device=self.device) * self.obs_params["K"]
    # V_Lambda_kappa = torch.eye(self.obs_params["K"], device=self.device) * (1. / n_Lambda_kappa)
    # Lambda_kappa = pyro.sample(f"Lambda_kappa", dist.Wishart(df=n_Lambda_kappa, covariance_matrix=V_Lambda_kappa))

    with poutine.mask(mask=include_prior):
        delta_k = pyro.sample(
            "delta_k", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))
        delta_kappa = pyro.sample(
            "delta_kappa", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))
        delta_rho = pyro.sample(
            "delta_rho", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))

        eta = torch.tensor([1 / self.obs_params["V"]] * self.obs_params["V"], device=self.device)
        loc_rho = torch.tensor([.5] * self.obs_params["X"], device=self.device)
        scale_rho = torch.eye(self.obs_params["X"], device=self.device) * 0.05
        loc_alpha = torch.tensor([1.] * self.obs_params["K"], device=self.device)
        scale_alpha = torch.eye(self.obs_params["K"], device=self.device) * 0.05
        loc_beta = torch.tensor([5.] * self.obs_params["K"], device=self.device)
        scale_beta = torch.eye(self.obs_params["K"], device=self.device) * 0.5
        a_tau = torch.tensor(2., device=self.device)
        b_tau = torch.tensor(1., device=self.device)

        # initialize gamma
        gamma = []

        # sample topic-level parameters and per-document topic scores
        with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
            varphi = pyro.sample("varphi", dist.Dirichlet(eta))
            rho = pyro.sample("rho", dist.MultivariateNormal(loc_rho, scale_rho))
            alpha = pyro.sample("alpha", dist.MultivariateNormal(loc_alpha, scale_alpha))
            beta = pyro.sample("beta", dist.MultivariateNormal(loc_beta, scale_beta))
            tau = pyro.sample("tau", dist.Gamma(a_tau, b_tau))
            with pyro.plate("user_plate", self.obs_params["I"], dim=-2):
                kappa = pyro.sample(
                    "kappa", dist.Normal(torch.tensor(1., device=self.device), torch.tensor(.5, device=self.device)))
                for d in range(self.obs_params["document_sequence_length"]):
                    x = self.obs_data["x"][d]  # shape=(I, X)
                    if d == 0:
                        gamma_loc = delta_k + delta_kappa * kappa + delta_rho * (x @ rho.T)  # shape=(I, K)
                        gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau))
                        gamma.append(gamma_d)

                    else:
                        timedelta = self.obs_data["timedelta"][d - 1, ...]
                        time_decay_coef = alpha * (-beta.unsqueeze(-3) * timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
                        last_gamma = gamma[d - 1].unsqueeze(-1)
                        time_decay_term = torch.bmm(time_decay_coef, last_gamma)
                        time_decay_term = time_decay_term.squeeze(-1)
                        gamma_loc = kappa + time_decay_term + x @ rho.T
                        # gamma_loc = kappa + torch.bmm(time_decay_coef, last_gamma).squeeze(-1) + x @ rho.T
                        gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau))
                        gamma.append(gamma_d)

        gamma = torch.stack(gamma)

        theta = F.softmax(gamma, dim=-1)  # shape=(D, I, K), document topic distribution

    # when generating data, sample words instead of conditioning on the observed text
    obs_words = None if self._uncondition_flag else self.obs_data["text"]  # shape=(N, D, I)

    # sample document words
    with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1):  # I
        with pyro.plate("document_sequence_length", self.obs_params["document_sequence_length"], dim=-2):  # D
            with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3):  # N
                word_topic = pyro.sample("topic_of_each_word",
                                         dist.Categorical(theta),
                                         infer={"enumerate": "parallel"})  # shape=(N, D, I)
                p_word = Vindex(varphi)[word_topic]
                word = pyro.sample("words", dist.Categorical(p_word), obs=obs_words)
                if self._uncondition_flag is True:
                    self.obs_data["text"] = word

def _guide(self, use_autoguide=True):
    if use_autoguide:
        return AutoDelta(self._model)
    else:
        # learnable parameters
        delta_k_loc = pyro.param("delta_k_loc", lambda: torch.tensor(0.5, device=self.device))
        delta_k_scale = pyro.param("delta_k_scale",
                                   lambda: torch.tensor(0.05, device=self.device),
                                   constraint=constraints.greater_than(0.0))
        delta_kappa_loc = pyro.param("delta_kappa_loc", lambda: torch.tensor(0.5, device=self.device))
        delta_kappa_scale = pyro.param("delta_kappa_scale",
                                       lambda: torch.tensor(0.05, device=self.device),
                                       constraint=constraints.greater_than(0.0))
        delta_rho_loc = pyro.param("delta_rho_loc", lambda: torch.tensor(0.5, device=self.device))
        delta_rho_scale = pyro.param("delta_rho_scale",
                                     lambda: torch.tensor(0.05, device=self.device),
                                     constraint=constraints.greater_than(0.0))
        eta_v = pyro.param(
            "eta_v",
            lambda: torch.tensor([1 / self.obs_params["V"]] * self.obs_params["V"], device=self.device),
            constraint=constraints.simplex)
        loc_rho = pyro.param("loc_rho", lambda: torch.tensor([.5] * self.obs_params["X"], device=self.device))
        scale_rho = pyro.param("scale_rho",
                               lambda: torch.eye(self.obs_params["X"], device=self.device) * 0.05,
                               constraint=constraints.positive)
        loc_alpha = pyro.param("loc_alpha",
                               lambda: torch.tensor([1.] * self.obs_params["K"], device=self.device))
        scale_alpha = pyro.param("scale_alpha",
                                 lambda: torch.eye(self.obs_params["K"], device=self.device) * 0.05,
                                 constraint=constraints.positive)
        loc_beta = pyro.param("loc_beta",
                              lambda: torch.tensor([2.] * self.obs_params["K"], device=self.device),
                              constraint=constraints.greater_than(1.0))
        scale_beta = pyro.param("scale_beta",
                                lambda: torch.eye(self.obs_params["K"], device=self.device) * 0.5,
                                constraint=constraints.positive)
        a_tau = pyro.param("a_tau",
                           lambda: torch.tensor(1., device=self.device),
                           constraint=constraints.greater_than(0.0))
        b_tau = pyro.param("b_tau",
                           lambda: torch.tensor(1., device=self.device),
                           constraint=constraints.greater_than(0.0))
        loc_kappa = pyro.param("loc_kappa", lambda: torch.tensor(0., device=self.device))
        scale_kappa = pyro.param("scale_kappa",
                                 lambda: torch.tensor(0.05, device=self.device),
                                 constraint=constraints.greater_than(0.0))

        # rvs
        delta_k_q = pyro.sample("delta_k", dist.Normal(delta_k_loc, delta_k_scale))
        delta_kappa_q = pyro.sample("delta_kappa", dist.Normal(delta_kappa_loc, delta_kappa_scale))
        delta_rho_q = pyro.sample("delta_rho", dist.Normal(delta_rho_loc, delta_rho_scale))

        gamma = []

        with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
            varphi_q = pyro.sample("varphi", dist.Dirichlet(eta_v))
            rho_q = pyro.sample("rho", dist.MultivariateNormal(loc_rho, scale_rho))
            alpha_q = pyro.sample("alpha", dist.MultivariateNormal(loc_alpha, scale_alpha))
            beta_q = pyro.sample("beta", dist.MultivariateNormal(loc_beta, scale_beta))
            tau_q = pyro.sample("tau", dist.Gamma(a_tau, b_tau))
            with pyro.plate("user_plate", self.obs_params["I"], dim=-2):
                kappa_q = pyro.sample("kappa", dist.Normal(loc_kappa, scale_kappa))
                for d in range(self.obs_params["document_sequence_length"]):
                    x = self.obs_data["x"][d]  # shape=(I, X)
                    if d == 0:
                        gamma_loc = delta_k_q + delta_kappa_q * kappa_q + delta_rho_q * (x @ rho_q.T)  # shape=(I, K)
                        gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau_q))
                        gamma.append(gamma_d)
                    else:
                        timedelta = self.obs_data["timedelta"][d - 1, ...]
                        time_decay_coef = alpha_q * (-beta_q.unsqueeze(-3) *
                                                     timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
                        last_gamma = gamma[d - 1].unsqueeze(-1)
                        time_decay_term = torch.bmm(time_decay_coef, last_gamma)
                        time_decay_term = time_decay_term.squeeze(-1)
                        gamma_loc = kappa_q + time_decay_term + x @ rho_q.T
                        # gamma_loc = kappa_q + torch.bmm(time_decay_coef, last_gamma).squeeze(-1) + x @ rho_q.T
                        gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau_q))
                        gamma.append(gamma_d)

        gamma = torch.stack(gamma)

        theta = F.softmax(gamma, dim=-1)

        with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1):  # I
            with pyro.plate("document_sequence_length", self.obs_params["document_sequence_length"], dim=-2):  # D
                with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3):  # N
                    word_topic = pyro.sample("topic_of_each_word",
                                             dist.Categorical(theta),
                                             infer={"enumerate": "parallel"})  # shape=(N, D, I)
                    # "words" is observed in the model, so it must not be sampled here in the guide

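For reference, a minimal training setup for a model/guide pair like this looks as follows. This is a sketch rather than my exact script: the instance name ctm, the Adam settings, and max_plate_nesting=3 (matching the three nested plates above) are placeholders.

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

ctm = ...  # placeholder: an instance of the class defining _model/_guide above
guide = ctm._guide(use_autoguide=True)
svi = SVI(ctm._model, guide, Adam({"lr": 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=3))

for step in range(1000):
    loss = svi.step()
    print(f"On iteration {step}, loss = {loss}")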
When I tried to train this model, it looked good initially, but after a few iterations the loss turns negative and doesn’t seem to converge. I used the poutine.uncondition() handler to reuse the Pyro model to generate simulation data for validating the model (sketched below). Now I’m confused: I don’t know whether there is a problem with the generated simulation data or an error in my code.
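A minimal sketch of how poutine.uncondition is used for this (the handler itself is standard Pyro; ctm is again a placeholder instance, and this is slightly simpler than the _uncondition_flag mechanism in the model above):

from pyro import poutine

# uncondition() ignores the `obs=` keyword at every sample site, so the
# model draws fresh words instead of conditioning on the observed text.
simulated_trace = poutine.trace(poutine.uncondition(ctm._model)).get_trace()
simulated_text = simulated_trace.nodes["words"]["value"]  # shape=(N, D, I)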

[2022-09-19 19:33:29 INFO]: On iteration 617, loss = 108.1953125
[2022-09-19 19:33:30 INFO]: On iteration 618, loss = 148.8046875
[2022-09-19 19:33:30 INFO]: On iteration 619, loss = 136.1484375
[2022-09-19 19:33:30 INFO]: On iteration 620, loss = 76.56640625
[2022-09-19 19:33:30 INFO]: On iteration 621, loss = 117.4921875
[2022-09-19 19:33:30 INFO]: On iteration 622, loss = 89.83984375
[2022-09-19 19:33:30 INFO]: On iteration 623, loss = 102.1171875
[2022-09-19 19:33:30 INFO]: On iteration 624, loss = 44.41796875
[2022-09-19 19:33:31 INFO]: On iteration 625, loss = 51.1875
[2022-09-19 19:33:31 INFO]: On iteration 626, loss = 112.4296875
[2022-09-19 19:33:31 INFO]: On iteration 627, loss = 73.48828125
[2022-09-19 19:33:31 INFO]: On iteration 628, loss = 43.6484375
[2022-09-19 19:33:31 INFO]: On iteration 629, loss = 51.51171875
[2022-09-19 19:33:31 INFO]: On iteration 630, loss = 6.7734375
[2022-09-19 19:33:32 INFO]: On iteration 631, loss = 24.2265625
[2022-09-19 19:33:32 INFO]: On iteration 632, loss = 78.203125
[2022-09-19 19:33:32 INFO]: On iteration 633, loss = -18.22265625
[2022-09-19 19:33:32 INFO]: On iteration 634, loss = -36.3828125
[2022-09-19 19:33:32 INFO]: On iteration 635, loss = -10.390625
[2022-09-19 19:33:32 INFO]: On iteration 636, loss = -54.13671875
[2022-09-19 19:33:32 INFO]: On iteration 637, loss = -52.375
[2022-09-19 19:33:33 INFO]: On iteration 638, loss = -20.8984375
[2022-09-19 19:33:33 INFO]: On iteration 639, loss = -10.453125
[2022-09-19 19:33:33 INFO]: On iteration 640, loss = -97.80859375
[2022-09-19 19:33:33 INFO]: On iteration 641, loss = -41.1796875
[2022-09-19 19:33:33 INFO]: On iteration 642, loss = -17.46875
[2022-09-19 19:33:33 INFO]: On iteration 643, loss = -35.69921875
[2022-09-19 19:33:34 INFO]: On iteration 644, loss = -126.4375

I haven’t looked at your model, but in general:

  • It’s perfectly reasonable for the ELBO loss to be negative if you have any continuous variables. The ELBO loss term for a continuous variable is a negative log-density, and since a density can exceed 1, its log can be positive. Depending on units, the log-density can also be arbitrarily shifted up or down; for example, switching some parameter from inches to meters will shift it by log(inches/meter). (See the sketch after this list.)
  • When using Pyro’s SVI, the ELBO loss is a stochastic estimate. If you want a more accurate estimate for evaluation, you can set num_particles to a large number, but that would slow down learning.
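To make both points concrete, here is a small self-contained illustration (mine, not code from this thread): a fully observed continuous model whose data sit in a high-density region gives a negative loss, and num_particles averages several ELBO samples for a lower-variance estimate (TraceEnum_ELBO takes the same argument).

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model():
    # the density of Normal(0, 0.01) at 0 is about 39.9, so the
    # log-density is positive and the negative log-density is negative
    pyro.sample("x", dist.Normal(torch.tensor(0.), torch.tensor(0.01)),
                obs=torch.tensor(0.))

def guide():
    pass

# num_particles averages many ELBO samples: lower variance, slower steps
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO(num_particles=100))
print(svi.evaluate_loss())  # ≈ -3.69, a perfectly valid negative ELBO loss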

Thanks for your reply! The problem has been solved. The source of the large negative loss was the parameter of the Dirichlet distribution in this model: I had set a symmetric concentration for the Dirichlet when generating the simulation data, which seems hard to learn via SVI and produced a diverging trend in the loss plot. After modifying this parameter I get a sensible loss curve, and the model now works well!
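For anyone who hits the same thing: a symmetric Dirichlet with concentration below 1, like the 1/V prior above, has unbounded density near the corners of the simplex, so its log-density terms are typically large and positive and can push the reported loss strongly negative. A quick way to see the magnitude (V = 50 is just an example; float64 avoids underflow at the corners):

import torch
import pyro.distributions as dist

V = 50
sparse = dist.Dirichlet(torch.full((V,), 1.0 / V, dtype=torch.float64))
s = sparse.sample()
# with concentration < 1, mass piles up at the simplex corners, so
# log_prob of a typical draw is a large positive number
print(sparse.log_prob(s))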
