I’m trying to implement a dynamic-style topic model. Here is my model code:
def _model(self):
    """
    (Core) Implements the `model` structure of a Pyro model.

    Sampling order:

    * global scalars ``delta_k``, ``delta_kappa``, ``delta_rho``;
    * per-topic parameters ``varphi``, ``rho``, ``alpha``, ``beta``, ``tau``
      inside ``topic_plate`` (size K, dim=-1);
    * per-user ``kappa`` and one ``gamma_{d}`` per sequence position inside
      ``user_plate`` (size I, optionally subsampled, dim=-2).  ``gamma_0`` is
      centred on a linear function of ``kappa`` and the covariates ``x``;
      later positions add a time-decayed transform of the previous ``gamma``;
    * a topic and a word for every token, with ``theta = softmax(gamma)``.

    Users are subsampled exactly once, in ``user_plate``.  The word plate
    ``user_plate_2`` is handed the same indices through ``subsample=``, so the
    word log-likelihood gets the same ``I / subsample_size`` scale as the
    per-user latents and refers to the same users as ``theta``.  (Declaring
    that plate with size ``subsample_size`` instead would drop the scale and
    under-weight the word likelihood in the ELBO.)

    Side effect: when ``self._uncondition_flag`` is truthy, the words are
    sampled rather than observed and stored in ``self.obs_data["text"]``.
    With subsampling that tensor covers only the subsampled users, giving
    shape (N, D, subsample_size).
    """
    prior = self._model_prior
    n_users = self.obs_params["I"]
    seq_len = self.obs_params["document_sequence_length"]

    # subsample size: None means every user is used on each step
    if self._subsample_size is None:
        user_subsample_size = n_users
    else:
        user_subsample_size = self._subsample_size["I"]

    delta_k = pyro.sample("delta_k", dist.Normal(prior["delta_k_loc"], prior["delta_k_scale"]))
    delta_kappa = pyro.sample(
        "delta_kappa", dist.Normal(prior["delta_kappa_loc"], prior["delta_kappa_scale"]))
    delta_rho = pyro.sample("delta_rho", dist.Normal(prior["delta_rho_loc"], prior["delta_rho_scale"]))

    gamma = []
    # topic-level parameters, then per-user document-topic logits
    with pyro.plate("topic_plate", self.obs_params["K"], dim=-1):
        varphi = pyro.sample("varphi", dist.Dirichlet(prior["eta"]))
        rho = pyro.sample("rho", dist.MultivariateNormal(prior["rho_loc"], prior["rho_scale"]))
        alpha = pyro.sample("alpha", dist.MultivariateNormal(prior["alpha_loc"], prior["alpha_scale"]))
        beta = pyro.sample("beta", dist.MultivariateNormal(prior["beta_loc"], prior["beta_scale"]))
        tau = pyro.sample("tau", dist.Gamma(prior["tau_a"], prior["tau_b"]))
        with pyro.plate("user_plate", n_users, subsample_size=user_subsample_size,
                        dim=-2) as u_ind_model:
            u_ind_model = u_ind_model.to(self.device)
            kappa = pyro.sample("kappa", dist.Normal(prior["kappa_loc"], prior["kappa_scale"]))
            for d in range(seq_len):
                # covariates of the subsampled users at position d
                x = self.obs_data["x"][d].index_select(0, u_ind_model)
                if d == 0:
                    gamma_loc = delta_k + delta_kappa * kappa + delta_rho * (x @ rho.T)
                else:
                    timedelta = self.obs_data["timedelta"][d - 1].index_select(0, u_ind_model)
                    # alpha * exp(-beta * timedelta), with timedelta broadcast per user;
                    # presumably (B, K, K) when alpha/beta are (K, K) -- TODO confirm
                    time_decay_coef = alpha * (
                        -beta.unsqueeze(-3) * timedelta.unsqueeze(-1).unsqueeze(-1)).exp()
                    # Batched matrix-vector product with the previous gamma.  `@`
                    # (torch.matmul) equals torch.bmm on 3-D inputs but also accepts
                    # extra leading batch dims, which bmm rejects.
                    time_decay_term = (time_decay_coef @ gamma[d - 1].unsqueeze(-1)).squeeze(-1)
                    gamma_loc = kappa + time_decay_term + x @ rho.T
                gamma.append(pyro.sample(f"gamma_{d}", dist.Normal(gamma_loc, tau)))

    gamma = torch.stack(gamma)  # (D, B, K)
    theta = F.softmax(gamma, dim=-1)  # document-topic distribution, (D, B, K)

    if self._uncondition_flag:
        obs_words = None  # generate words instead of conditioning on them
    else:
        # observed words are (N, D, I); keep only the users subsampled above
        obs_words = self.obs_data["text"].index_select(2, u_ind_model)

    # sample document words.  Reuse the users drawn in `user_plate`: `subsample=`
    # keeps the I / B scaling while guaranteeing this plate's indices match
    # theta and obs_words.  With only `subsample_size=`, Pyro would draw a
    # fresh, unrelated subsample here.
    with pyro.plate("user_plate_2", n_users, subsample=u_ind_model, dim=-1):  # B
        with pyro.plate("document_sequence_length", seq_len, dim=-2):  # D
            with pyro.plate("docs_length", self.obs_params["document_length"], dim=-3):  # N
                word_topic = pyro.sample("topic_of_each_word",
                                         dist.Categorical(theta),
                                         infer={"enumerate": "parallel"})  # shape=(N, D, B)
                p_word = Vindex(varphi)[word_topic]
                word = pyro.sample("words", dist.Categorical(p_word), obs=obs_words)

    if obs_words is None:
        self.obs_data["text"] = word
To draw the topic proportions, I used subsampling for `user_plate` (the code before `# sample document word`). I’m wondering whether I need to keep using subsampling when drawing the words, or whether I can just set the plate size to the subsample size. The corresponding code is:
with pyro.plate("user_plate", self.obs_params["I"], subsample_size=user_subsample_size,
dim=-2) as u_ind_model:
# sample document word
with pyro.plate("user_plate_2", self.obs_params["I"], subsample_size=user_subsample_size, dim=-1): # I