Hey guys,
I found an issue with the Dirichlet distribution. Sometimes, for no apparent reason, it outputs a tensor full of NaNs. The line causing the problem is as simple as:
```python
doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
```
`topic_weights` is asserted to be a valid tensor every time.
For now, I’m working around the problem by replacing the line above with the following code:
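```python
import pyro
import pyro.distributions as dist

repeat = True
i = 0
while repeat:
    # Each retry needs a fresh site name, because Pyro sample names must be unique
    doc_topics = pyro.sample("doc_topics_%d" % i, dist.Dirichlet(topic_weights))
    repeat = doc_topics.isnan().any()  # resample if any entry is NaN
    i += 1
```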
It works, but it is as ugly as code can possibly be (in Brazil we have a slang word for this kind of makeshift fix: “gambiarra”).
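A less hacky guard I’m considering is sketched below. It rests on an unconfirmed assumption: that the NaNs come from non-finite or extremely small concentration values rather than from a bug in the sampler itself. The helper `sanitize_concentration` and the floor `MIN_CONCENTRATION = 1e-3` are hypothetical names and an arbitrary value of my own, not anything from Pyro or PyTorch. The sketch uses plain `torch.distributions.Dirichlet` so it runs without Pyro:

```python
import torch
from torch.distributions import Dirichlet

# Hypothetical floor for the concentration; 1e-3 is an arbitrary
# illustrative value, not a Pyro or PyTorch default.
MIN_CONCENTRATION = 1e-3


def sanitize_concentration(concentration: torch.Tensor,
                           min_value: float = MIN_CONCENTRATION) -> torch.Tensor:
    """Fail loudly on NaN/inf input and clamp tiny or non-positive entries."""
    if not torch.isfinite(concentration).all():
        raise ValueError("concentration contains NaN or inf")
    # Dirichlet needs strictly positive concentrations; very small values are
    # an assumed (not confirmed) cause of underflow during sampling.
    return concentration.clamp(min=min_value)


# Example with one extremely small weight
topic_weights = torch.tensor([1e-12, 2.0, 0.5])
safe_weights = sanitize_concentration(topic_weights)
doc_topics = Dirichlet(safe_weights).sample()
print(doc_topics, doc_topics.sum())
```

In the model, the sanitized tensor would go into `dist.Dirichlet(...)` in place of `topic_weights`, so there is a single `doc_topics` site instead of a changing `doc_topics_%d` name on every retry. The catch is that clamping slightly changes the distribution being sampled.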
Any hints on how to solve that?
I noticed that someone else had the same issue in the past, and one of the core devs suggested they post the question on the PyTorch forum (which they didn’t). Anyway, I will post the same question on the PyTorch forum…