Hi
I am changing the LDA code but it keeps giving me NaN and inf loss. I checked the log_probs, they are infinite as well. Any pointers are appreciated.
The guide and model are given below
indent preformatted text by 4 spaces
def guide(self,users,annealing_factor=1.0):
K = 5
V = 5
# register PyTorch module with Pyro
pyro.module(“rnn”, self.rnn,update_module_params=False)
pyro.module(“autoEn”, self.autoEn,update_module_params=False)
pyro.module(“combiner”, self.combiner,update_module_params=True)
topic_words_posterior = pyro.param("topic_words_posterior",torch.ones(K,V)/V, constraint=constraints.positive)
phi_q = pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior).independent(1))
with pyro.plate("Users",len(users), 1) as u:
# load data
#Pass through Rnn
user_hist_rnn,hidden = self.rnn.forward(temp_data.view(1,temp_data.shape[0],temp_data.shape[1]),hidden)
topic_weights_posterior = pyro.param("topic_weights_posterior_%d"%u,
torch.ones(K)/K,
constraint=constraints.simplex)
alpha_q = pyro.sample("topic_weights",dist.Dirichlet(
topic_weights_posterior))
theta_d = pyro.sample("doc_topics",dist.Dirichlet(alpha_q))
self.alphas.append(theta_d)
with pyro.plate("time",len(user_hist_rnn)) as t:
zq = pyro.sample("word_topics", dist.Categorical(theta_d))
bias_con = Variable(dist.Dirichlet(phi_q[zq,:]).sample()).type(self.dtype)
temp_data = Variable(torch.Tensor(self.read_user_data(
users[u]))).type(self.dtype)
usr_con = self.autoEn.forward_en(temp_data)
alpha_f = (self.combiner(bias_con.type(self.dtype).view(1,K),
user_hist_rnn[:,t,:].view(len(t),user_hist_rnn.shape[2]),
usr_con)) # the final layer is a softplus
pyro.sample("doc_words",dist.Dirichlet(alpha_f).independent(1))
def model(self,users,annealing_factor=1.0):
K = 5
V = 5
% read data
# Globals.
with pyro.plate("topics",K):
topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / K, 1.))
topic_words = pyro.sample("topic_words",dist.Dirichlet(
torch.ones(V) / V))
topic_words = topic_words/sum(topic_words,0)
with pyro.plate("Users",len(users), 1) as u:
usr_hist_dt = usr_dt[users[u]]
doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
with pyro.plate("time",len(usr_hist_dt)) as t:
word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),infer={"enumerate": "parallel"})
obs_val = Variable(usr_hist_dt['y'][t,:]).type(self.dtype)
pyro.sample("doc_words", dist.Dirichlet(
topic_words[word_topics,:].view(len(t),V)).independent(1),obs=obs_val)