I found an issue with the Dirichlet distribution: sometimes, for no apparent reason, it returns a tensor full of NaNs. The line causing the problem is as simple as:
doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
topic_weights is always asserted to be a valid tensor.
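For a Dirichlet, "valid" concretely means every concentration entry is finite and strictly positive. Below is a minimal sketch of one way to enforce that in plain PyTorch (`torch.distributions.Dirichlet`, which Pyro's `dist.Dirichlet` builds on). The helper name `make_checked_dirichlet` is hypothetical. Passing `validate_args=True` makes the constructor raise `ValueError` on zero, negative or NaN entries; the explicit `torch.isfinite` check is there because `+inf` passes a plain positivity check.

```python
import torch
from torch.distributions import Dirichlet


def make_checked_dirichlet(concentration: torch.Tensor) -> Dirichlet:
    """Hypothetical helper: build a Dirichlet only from finite, strictly positive weights."""
    # The positivity constraint rejects 0, negatives and NaN, but +inf would pass it,
    # so check finiteness explicitly first.
    if not torch.isfinite(concentration).all():
        raise ValueError("concentration contains inf or NaN")
    # validate_args=True makes the constructor enforce the positivity constraint.
    return Dirichlet(concentration, validate_args=True)


# Example: a valid concentration vector builds and samples normally.
weights = torch.tensor([0.5, 2.0, 3.0])
sample = make_checked_dirichlet(weights).sample()
```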
For now, I'm working around the problem by replacing the doc_topics sampling line above with the following code:
repeat = True
i = 0
while repeat:
    doc_topics = pyro.sample("doc_topics_%d" % i, dist.Dirichlet(topic_weights))
    repeat = doc_topics.isnan().any()
    i += 1
It works… but it is as ugly as code can possibly be. (In Brazil we have a slang word for that: “gambiarra”, roughly a makeshift kludge.)
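If a retry is kept as a stopgap, it can at least be bounded and wrapped in a function so it cannot loop forever. This is a minimal sketch in plain PyTorch, outside Pyro, so no sample-site names are involved; `sample_without_nan` and `max_tries` are hypothetical names. Note that rejecting NaN draws hides the symptom rather than fixing its cause, and inside a Pyro model the extra renamed sample sites from the original loop also end up in the trace.

```python
import torch
from torch.distributions import Dirichlet


def sample_without_nan(concentration: torch.Tensor, max_tries: int = 10) -> torch.Tensor:
    """Hypothetical helper: redraw a Dirichlet sample until it has no NaN, at most max_tries times."""
    dirichlet = Dirichlet(concentration)
    for _ in range(max_tries):
        sample = dirichlet.sample()
        if not torch.isnan(sample).any():
            return sample
    # Fail loudly instead of spinning forever when every draw is NaN.
    raise RuntimeError(f"Dirichlet produced NaN in {max_tries} consecutive draws")


topic_weights = torch.tensor([0.1, 0.1, 0.1])
doc_topics = sample_without_nan(topic_weights)
```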
Any hints on how to solve that?
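One commonly suggested workaround, offered here as an assumption rather than a confirmed diagnosis for this case: Dirichlet samples are drawn as normalized Gamma variates, and with very small concentration values in float32 those variates may underflow so the normalization degenerates. Putting a floor under the concentration, and optionally sampling in float64, keeps the draws away from that regime. This is a minimal sketch; `safe_dirichlet` and the floor `1e-3` are hypothetical choices, and clamping slightly changes the prior.

```python
import torch
from torch.distributions import Dirichlet


def safe_dirichlet(concentration: torch.Tensor, floor: float = 1e-3,
                   use_float64: bool = True) -> Dirichlet:
    """Hypothetical helper: Dirichlet with the concentration clamped away from zero."""
    if use_float64:
        # Double precision moves the underflow threshold much further out.
        concentration = concentration.double()
    # Raise tiny entries up to `floor`; this slightly alters the prior.
    return Dirichlet(concentration.clamp(min=floor))


topic_weights = torch.tensor([1e-8, 0.5, 2.0])
doc_topics = safe_dirichlet(topic_weights).sample()
```

In the original model the same idea would read roughly `pyro.sample("doc_topics", dist.Dirichlet(topic_weights.clamp(min=1e-3)))`, with the floor chosen to suit the model.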
I noticed that someone had the same issue in the past, and one of the core devs suggested that he post the question on the PyTorch forum (which he didn't). Anyway, I will post the same question on the PyTorch forum as well.