- What tutorial are you running? LDA
- What version of Pyro are you using? 0.3
- Please link or paste relevant code, and steps to reproduce.
I believe there is an error in the LDA example code. More precisely, the input parameters are subject to a constraint that is not stated anywhere: `num-words` must not be smaller than `num-words-per-doc`.
To check this, run the code twice:
- First run: num-words is 65. The code runs fine.
- Second run: num-words is 63. The code crashes.
The problem is with this line in `parametrized_guide`:

```python
counts.scatter_add_(0, data[:, ind], torch.tensor(1.).expand(counts.shape))
```

`scatter_add_` places a restriction on the dimensions of its `index` and `src` arguments.
My question is: why is there such a restriction, and how can I get the same behavior as `scatter_add_` without it? Thanks.
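One possible workaround, sketched under assumptions rather than tested against the full example: the PyTorch documentation for `scatter_add_` requires `index.size(d) <= src.size(d)` for every dimension `d`. Assuming `data[:, ind]` has shape `(num_words_per_doc, batch_size)` while `torch.tensor(1.).expand(counts.shape)` has shape `(num_words, batch_size)`, the call fails whenever `num_words < num_words_per_doc`. Giving the source tensor the same shape as the index removes that dependency. The helper `bag_of_words_counts` is hypothetical and not part of Pyro.

```python
import torch


def bag_of_words_counts(data, num_words):
    """Hypothetical helper: count word occurrences per document.

    data: int64 tensor of shape (num_words_per_doc, batch_size) holding word ids
    (assumed layout, matching data[:, ind] in the guide).
    Returns a float tensor of shape (num_words, batch_size).
    """
    counts = torch.zeros(num_words, data.size(1))
    # The source tensor matches the *index* shape instead of counts.shape,
    # so scatter_add_ no longer needs num_words >= num_words_per_doc.
    ones = torch.ones(data.shape)
    counts.scatter_add_(0, data, ones)
    return counts


if __name__ == "__main__":
    torch.manual_seed(0)
    # More words per document than vocabulary entries: the case that crashes.
    num_words, num_words_per_doc, batch_size = 5, 8, 3
    data = torch.randint(0, num_words, (num_words_per_doc, batch_size))
    counts = bag_of_words_counts(data, num_words)
    print(counts.sum(0))  # tensor([8., 8., 8.])
```

Inside the guide this would amount to passing `torch.ones(data[:, ind].shape)` as the third argument. `torch.nn.functional.one_hot(data, num_words).sum(0).t()` gives the same counts without `scatter_add_`, at the cost of materializing a dense one-hot tensor.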