Hey guys,
I found an issue with the Dirichlet distribution. Sometimes, for no apparent reason, sampling from it returns a tensor full of NaNs. The line causing the problem is as simple as:
doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
topic_weights is always asserted to be a valid tensor.
For now, I’m working around the problem by replacing the line above with the following code:
repeat = True
i = 0
while repeat:
    doc_topics = pyro.sample("doc_topics_%d" % i, dist.Dirichlet(topic_weights))
    repeat = doc_topics.isnan().any()
    i += 1
It works… but it is as ugly as code can possibly be. (In Brazil we have a slang word for that: “gambiarra”.)
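One possible explanation, which is only an assumption on my side and not something I have confirmed in the PyTorch source, is floating-point underflow. A Dirichlet sample can be generated by normalizing independent Gamma draws, and when every concentration value is very small, all of the Gamma draws can underflow to zero, which leaves a 0/0 normalization. The plain-Python sketch below illustrates that mechanism only. It is not the PyTorch or Pyro sampler, and the names dirichlet_via_gammas and count_nan_draws are hypothetical helpers made up for this illustration.

```python
import math
import random


def dirichlet_via_gammas(concentration, rng):
    """Draw one Dirichlet sample by normalizing independent Gamma(alpha_i, 1) draws.

    Hypothetical helper for illustration; not the PyTorch/Pyro implementation.
    """
    gammas = [rng.gammavariate(alpha, 1.0) for alpha in concentration]
    total = sum(gammas)
    if total == 0.0:
        # Every Gamma draw underflowed to 0.0, so the normalization is 0/0.
        # IEEE floating-point division gives NaN here; plain Python would
        # raise ZeroDivisionError, so the NaN is produced explicitly.
        return [math.nan] * len(gammas)
    return [g / total for g in gammas]


def count_nan_draws(concentration, n_draws=1000, seed=0):
    """Count how many of n_draws samples contain a NaN."""
    rng = random.Random(seed)
    bad = 0
    for _ in range(n_draws):
        sample = dirichlet_via_gammas(concentration, rng)
        if any(math.isnan(x) for x in sample):
            bad += 1
    return bad


if __name__ == "__main__":
    # Moderate concentrations: NaN draws are not expected.
    print("alpha = 1.0  ->", count_nan_draws([1.0, 1.0, 1.0]), "NaN draws")
    # Tiny concentrations: the Gamma draws frequently underflow to zero.
    print("alpha = 1e-4 ->", count_nan_draws([1e-4, 1e-4, 1e-4]), "NaN draws")
```

If this is what is happening in my model, printing topic_weights.min() right before the failing sample call should show values very close to zero whenever the NaNs appear.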
Any hints on how to solve that?
I noticed that someone already had the same issue in the past, and one of the core devs suggested that he post the question on the PyTorch forum (which he didn’t). Anyway, I will post the same question on the PyTorch forum…