I’m trying to build a time-series topic model, which seems like the time-series correlated topic model (CTM). Here is my model
and guide
code:
def _model(self, include_prior=True):
"""
Implements the data generation process by pyro model
"""
# loc_mu_kappa = torch.zeros(self.obs_params["K"], device=self.device) * 0.5
# scale_mu_kappa = torch.eye(self.obs_params["K"], device=self.device) * 0.05
# mu_kappa = pyro.sample(f"mu_kappa", dist.MultivariateNormal(loc_mu_kappa, scale_mu_kappa))
# n_Lambda_kappa = torch.tensor(3., device=self.device) * self.obs_params["K"]
# V_Lambda_kappa = torch.eye(self.obs_params["K"], device=self.device) * (1. / n_Lambda_kappa)
# Lambda_kappa = pyro.sample(f"Lambda_kappa", dist.Wishart(df=n_Lambda_kappa, covariance_matrix=V_Lambda_kappa))
with poutine.mask(mask=include_prior):
delta_k = pyro.sample(
"delta_k", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))
delta_kappa = pyro.sample(
"delta_kappa", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))
delta_rho = pyro.sample(
"delta_rho", dist.Normal(torch.tensor(.5, device=self.device), torch.tensor(0.05, device=self.device)))
eta = torch.tensor([1 / self.obs_params["V"]] * self.obs_params["V"], device=self.device)
loc_rho = torch.tensor([.5] * self.obs_params["X"], device=self.device)
scale_rho = torch.eye(self.obs_params["X"], device=self.device) * 0.05
loc_alpha = torch.tensor([1.] * self.obs_params["K"], device=self.device)
scale_alpha = torch.eye(self.obs_params["K"], device=self.device) * 0.05
loc_beta = torch.tensor([5.] * self.obs_params["K"], device=self.device)
scale_beta = torch.eye(self.obs_params["K"], device=self.device) * 0.5
a_tau = torch.tensor(2., device=self.device)
b_tau = torch.tensor(1., device=self.device)
# initialize gamma
gamma = []
# sample document-topic distributions
with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
varphi = pyro.sample(f"varphi", dist.Dirichlet(eta))
rho = pyro.sample(f"rho", dist.MultivariateNormal(loc_rho, scale_rho))
alpha = pyro.sample(f"alpha", dist.MultivariateNormal(loc_alpha, scale_alpha))
beta = pyro.sample(f"beta", dist.MultivariateNormal(loc_beta, scale_beta))
tau = pyro.sample(f"tau", dist.Gamma(a_tau, b_tau))
with pyro.plate("user_plate", self.obs_params["I"], dim=-2):
kappa = pyro.sample(
f"kappa", dist.Normal(torch.tensor(1., device=self.device), torch.tensor(.5, device=self.device)))
for d in range(self.obs_params["document_sequence_length"]):
x = self.obs_data["x"][d, ..., ...] # shape=(I, X)
if d == 0:
gamma_loc = delta_k + delta_kappa * kappa + delta_rho * (x @ rho.T) # shape=(I, K)
gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau))
gamma.append(gamma_d)
else:
timedelta = self.obs_data["timedelta"][d - 1, ...]
time_decay_coef = alpha * (-beta.unsqueeze(-3) * timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
last_gamma = gamma[d - 1].unsqueeze(-1)
time_decay_term = torch.bmm(time_decay_coef, last_gamma)
time_decay_term = time_decay_term.squeeze(-1)
gamma_loc = kappa + time_decay_term + x @ rho.T
# gamma_loc = kappa + torch.bmm(time_decay_coef, last_gamma).squeeze(-1) + x @ rho.T
gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau))
gamma.append(gamma_d)
gamma = torch.stack(gamma)
theta = F.softmax(gamma, dim=-1) # shape=(D, I, K), document topic distribution
if self._uncondition_flag is False:
obs_words = self.obs_data["text"] # shape=(N, D, I)
if self._uncondition_flag is True:
obs_words = None
# sample document word
with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1): # I
with pyro.plate("document_sequence_length", self.obs_params["document_sequence_length"], dim=-2): # D
with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3): # N
word_topic = pyro.sample(f"topic_of_each_word",
dist.Categorical(theta),
infer={"enumerate": "parallel"}) # shape=(N, D, I)
p_word = Vindex(varphi)[word_topic]
word = pyro.sample(f"words", dist.Categorical(p_word), obs=obs_words)
if self._uncondition_flag is True:
self.obs_data["text"] = word
def _guide(self, use_autoguide=True):
if use_autoguide:
return AutoDelta(self._model)
else:
# learnable parameters
delta_k_loc = pyro.param(f"delta_k_loc", lambda: torch.tensor(0.5, device=self.device))
delta_k_scale = pyro.param(f"delta_k_scale",
lambda: torch.tensor(0.05, device=self.device),
constraint=constraints.greater_than(0.0))
delta_kappa_loc = pyro.param(f"delta_kappa_loc", lambda: torch.tensor(0.5, device=self.device))
delta_kappa_scale = pyro.param(f"delta_kappa_scale",
lambda: torch.tensor(0.05, device=self.device),
constraint=constraints.greater_than(0.0))
delta_rho_loc = pyro.param(f"delta_rho_loc", lambda: torch.tensor(0.5, device=self.device))
delta_rho_scale = pyro.param(f"delta_rho_scale",
lambda: torch.tensor(0.05, device=self.device),
constraint=constraints.greater_than(0.0))
eta_v = pyro.param(
f"eta_v",
lambda: torch.tensor([1 / self.obs_params["V"]] * self.obs_params["V"], device=self.device),
constraint=constraints.simplex)
loc_rho = pyro.param(f"loc_rho", lambda: torch.tensor([.5] * self.obs_params["X"], device=self.device))
scale_rho = pyro.param(f"scale_rho",
lambda: torch.eye(self.obs_params["X"], device=self.device) * 0.05,
constraint=constraints.positive)
loc_alpha = pyro.param(f"loc_alpha",
lambda: torch.tensor([1.] * self.obs_params["K"], device=self.device))
scale_alpha = pyro.param(f"scale_alpha",
lambda: torch.eye(self.obs_params["K"], device=self.device) * 0.05,
constraint=constraints.positive)
loc_beta = pyro.param(f"loc_beta",
lambda: torch.tensor([2.] * self.obs_params["K"], device=self.device),
constraint=constraints.greater_than(1.0))
scale_beta = pyro.param(f"scale_beta",
lambda: torch.eye(self.obs_params["K"], device=self.device) * 0.5,
constraint=constraints.positive)
a_tau = pyro.param(f"a_tau",
lambda: torch.tensor(1., device=self.device),
constraint=constraints.greater_than(0.0))
b_tau = pyro.param(f"b_tau",
lambda: torch.tensor(1., device=self.device),
constraint=constraints.greater_than(0.0))
loc_kappa = pyro.param(f"loc_kappa", lambda: torch.tensor(0., device=self.device))
scale_kappa = pyro.param(f"scale_kappa",
lambda: torch.tensor(0.05, device=self.device),
constraint=constraints.greater_than(0.0))
# rvs
delta_k_q = pyro.sample(f"delta_k", dist.Normal(delta_k_loc, delta_k_scale))
delta_kappa_q = pyro.sample(f"delta_kappa", dist.Normal(delta_kappa_loc, delta_kappa_scale))
delta_rho_q = pyro.sample(f"delta_rho", dist.Normal(delta_rho_loc, delta_rho_scale))
gamma = []
with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
varphi_q = pyro.sample(f"varphi", dist.Dirichlet(eta_v))
rho_q = pyro.sample(f"rho", dist.MultivariateNormal(loc_rho, scale_rho))
alpha_q = pyro.sample(f"alpha", dist.MultivariateNormal(loc_alpha, scale_alpha))
beta_q = pyro.sample(f"beta", dist.MultivariateNormal(loc_beta, scale_beta))
tau_q = pyro.sample(f"tau", dist.Gamma(a_tau, b_tau))
with pyro.plate("user_plate", self.obs_params["I"], dim=-2):
kappa_q = pyro.sample(f"kappa", dist.Normal(loc_kappa, scale_kappa))
for d in range(self.obs_params["document_sequence_length"]):
x = self.obs_data["x"][d, ..., ...] # shape=(I, X)
if d == 0:
gamma_loc = delta_k_q + delta_kappa_q * kappa_q + delta_rho_q * (x @ rho_q.T
) # shape=(I, K)
gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau_q))
gamma.append(gamma_d)
else:
timedelta = self.obs_data["timedelta"][d - 1, ...]
time_decay_coef = alpha_q * (-beta_q.unsqueeze(-3) *
timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
last_gamma = gamma[d - 1].unsqueeze(-1)
time_decay_term = torch.bmm(time_decay_coef, last_gamma)
time_decay_term = time_decay_term.squeeze(-1)
gamma_loc = kappa_q + time_decay_term + x @ rho_q.T
# gamma_loc = kappa_q + torch.bmm(time_decay_term, last_gamma).squeeze(-1) + x @ rho_q.T
gamma_d = pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau_q))
gamma.append(gamma_d)
gamma = torch.stack(gamma)
theta = F.softmax(gamma, dim=-1)
with pyro.plate("user_plate_2", self.obs_params["I"], dim=-1): # I
with pyro.plate("document_sequence_length", self.obs_params["document_sequence_length"], dim=-2): # D
with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3): # N
word_topic = pyro.sample(f"topic_of_each_word",
dist.Categorical(theta),
infer={"enumerate": "parallel"}) # shape=(N, D, I)
p_word = Vindex(varphi_q)[word_topic]
word = pyro.sample(f"words", dist.Categorical(p_word))
When I tried to train this model, it looked good initially. But after a few iterations, the loss turns to a negative value and doesn’t seem to be converging. I used the pouting.uncondition()
func to reuse the pyro model to generate the simulation data to validate my model. Now I’m confused. I don’t know if there is a problem with the generated simulation data or if there is an error in my code.
[2022-09-19 19:33:29 INFO]: On iteration 617, loss = 108.1953125
[2022-09-19 19:33:30 INFO]: On iteration 618, loss = 148.8046875
[2022-09-19 19:33:30 INFO]: On iteration 619, loss = 136.1484375
[2022-09-19 19:33:30 INFO]: On iteration 620, loss = 76.56640625
[2022-09-19 19:33:30 INFO]: On iteration 621, loss = 117.4921875
[2022-09-19 19:33:30 INFO]: On iteration 622, loss = 89.83984375
[2022-09-19 19:33:30 INFO]: On iteration 623, loss = 102.1171875
[2022-09-19 19:33:30 INFO]: On iteration 624, loss = 44.41796875
[2022-09-19 19:33:31 INFO]: On iteration 625, loss = 51.1875
[2022-09-19 19:33:31 INFO]: On iteration 626, loss = 112.4296875
[2022-09-19 19:33:31 INFO]: On iteration 627, loss = 73.48828125
[2022-09-19 19:33:31 INFO]: On iteration 628, loss = 43.6484375
[2022-09-19 19:33:31 INFO]: On iteration 629, loss = 51.51171875
[2022-09-19 19:33:31 INFO]: On iteration 630, loss = 6.7734375
[2022-09-19 19:33:32 INFO]: On iteration 631, loss = 24.2265625
[2022-09-19 19:33:32 INFO]: On iteration 632, loss = 78.203125
[2022-09-19 19:33:32 INFO]: On iteration 633, loss = -18.22265625
[2022-09-19 19:33:32 INFO]: On iteration 634, loss = -36.3828125
[2022-09-19 19:33:32 INFO]: On iteration 635, loss = -10.390625
[2022-09-19 19:33:32 INFO]: On iteration 636, loss = -54.13671875
[2022-09-19 19:33:32 INFO]: On iteration 637, loss = -52.375
[2022-09-19 19:33:33 INFO]: On iteration 638, loss = -20.8984375
[2022-09-19 19:33:33 INFO]: On iteration 639, loss = -10.453125
[2022-09-19 19:33:33 INFO]: On iteration 640, loss = -97.80859375
[2022-09-19 19:33:33 INFO]: On iteration 641, loss = -41.1796875
[2022-09-19 19:33:33 INFO]: On iteration 642, loss = -17.46875
[2022-09-19 19:33:33 INFO]: On iteration 643, loss = -35.69921875
[2022-09-19 19:33:34 INFO]: On iteration 644, loss = -126.4375