I am interested in implementing a model like the one described in this paper in Pyro: https://arxiv.org/pdf/1206.6425.pdf
This model does something common in modern LDA implementations: it leverages the sparsity of the topic-word distributions to allow a large number of topics despite a large corpus. Without sparsity this would require O(K * V) memory, which is intractable for moderately large K and a large vocabulary size V.
In the appendix they develop a method to sparsely sample and update the vocabulary parameters (see page 8).
I am wondering whether it would be possible to implement something like this in Pyro in a way that works with the standard torch optimizers, or with the sparse ones? I think it should be possible if I write my own distribution with a custom backprop rule, but I am not sure.
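To illustrate the kind of "custom backprop rule" I have in mind, here is a minimal sketch (my own toy code, not from the paper or from Pyro) of a `torch.autograd.Function` whose backward pass emits a sparse COO gradient touching only the rows that were actually used, so an optimizer like `torch.optim.SparseAdam` would only update those rows. This is essentially what `nn.Embedding(..., sparse=True)` already does internally; I include it just to show the mechanism:

```python
import torch

class SparseRowGather(torch.autograd.Function):
    """Gather rows of a (K*V or V-sized) parameter matrix, returning a
    sparse gradient so only the touched rows are updated."""

    @staticmethod
    def forward(ctx, weight, rows):
        # rows: 1-D LongTensor of row indices into weight
        ctx.save_for_backward(rows)
        ctx.weight_shape = weight.shape
        return weight[rows]

    @staticmethod
    def backward(ctx, grad_output):
        (rows,) = ctx.saved_tensors
        n_rows, n_cols = ctx.weight_shape
        # Sparse COO gradient: nonzero only at the sampled rows.
        # Duplicate indices are summed when the optimizer coalesces.
        grad_weight = torch.sparse_coo_tensor(
            rows.unsqueeze(0), grad_output, size=(n_rows, n_cols)
        )
        return grad_weight, None  # no gradient w.r.t. the indices
```

My question is whether something along these lines can be plugged into a Pyro guide so that SVI never materializes the dense K * V gradient.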
Any guidance here would be hugely appreciated!