LDA model: NaNs and inf as loss

Hi

I am changing the LDA code but it keeps giving me NaN and inf loss. I checked the log_probs, they are infinite as well. Any pointers are appreciated.

The guide and model are given below
indent preformatted text by 4 spaces

def guide(self,users,annealing_factor=1.0):
K = 5
V = 5
# register PyTorch module with Pyro
pyro.module(“rnn”, self.rnn,update_module_params=False)
pyro.module(“autoEn”, self.autoEn,update_module_params=False)
pyro.module(“combiner”, self.combiner,update_module_params=True)

    topic_words_posterior = pyro.param("topic_words_posterior",torch.ones(K,V)/V, constraint=constraints.positive)    
    phi_q = pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior).independent(1))

    with pyro.plate("Users",len(users), 1) as u:
      
        # load data 
            
        #Pass through Rnn
        user_hist_rnn,hidden = self.rnn.forward(temp_data.view(1,temp_data.shape[0],temp_data.shape[1]),hidden)
            
        topic_weights_posterior = pyro.param("topic_weights_posterior_%d"%u,
                                        torch.ones(K)/K,
                                        constraint=constraints.simplex)
         alpha_q = pyro.sample("topic_weights",dist.Dirichlet(
                topic_weights_posterior))
            
         theta_d = pyro.sample("doc_topics",dist.Dirichlet(alpha_q))
            self.alphas.append(theta_d)
            
         with pyro.plate("time",len(user_hist_rnn)) as t:  

                zq = pyro.sample("word_topics", dist.Categorical(theta_d))
            
                bias_con = Variable(dist.Dirichlet(phi_q[zq,:]).sample()).type(self.dtype)

                temp_data = Variable(torch.Tensor(self.read_user_data(
                    users[u]))).type(self.dtype)
                
                usr_con = self.autoEn.forward_en(temp_data)
                
                alpha_f = (self.combiner(bias_con.type(self.dtype).view(1,K),
                               user_hist_rnn[:,t,:].view(len(t),user_hist_rnn.shape[2]),
                               usr_con)) # the final layer is a softplus
                
                
                
                pyro.sample("doc_words",dist.Dirichlet(alpha_f).independent(1))
            

def model(self,users,annealing_factor=1.0):
    K = 5
    V = 5
    % read data

    # Globals.
    
    with pyro.plate("topics",K):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / K, 1.))
        topic_words = pyro.sample("topic_words",dist.Dirichlet(
            torch.ones(V) / V))
        
    topic_words = topic_words/sum(topic_words,0)                         
    
    with pyro.plate("Users",len(users), 1) as u:
        usr_hist_dt = usr_dt[users[u]]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))

        
        with pyro.plate("time",len(usr_hist_dt)) as t:    
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),infer={"enumerate": "parallel"})

            obs_val = Variable(usr_hist_dt['y'][t,:]).type(self.dtype)                    

            pyro.sample("doc_words", dist.Dirichlet(
                topic_words[word_topics,:].view(len(t),V)).independent(1),obs=obs_val)

Hi @javenu I also find that the LDA example is quite numerically unstable. It was originally intended only to demonstrate enumeration machinery, and is not the recommended way of implementing LDA. I think we should try to create a more stable LDA model together with a tutorial. Contributions welcome!

Hi @fritzo . Thanks for the reply. After some analysis, I observe that adding the obs to the model makes this unstable and gives the Nan and inf. Also the sum of the Dirichlet is not exactly one. Do you have any pointers on how we can condition the model on observation without introducing instability?