Advice on implementing a sparse Dirichlet distribution in Pyro

I am interested in implementing a model like the one described here in Pyro:

This model does something common in modern LDA implementations: it leverages the sparsity of the topic-word distributions to allow a large number of topics K even when the corpus, and hence the vocabulary size V, is large. Without sparsity, the topic-word matrix would require O(K · V) memory, which is intractable for moderate K and large V.
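To put rough numbers on the memory argument, here is a back-of-the-envelope sketch. The sizes (K = 1,000 topics, V = 1,000,000 words, about 5,000 nonzero words per topic) are hypothetical, and the sparse estimate assumes a simple coordinate layout with an int64 topic index, an int64 word index, and a float32 value per nonzero entry.

```python
def dense_bytes(num_topics: int, vocab_size: int, bytes_per_value: int = 4) -> int:
    # Dense K x V topic-word matrix stored as float32.
    return num_topics * vocab_size * bytes_per_value


def sparse_bytes(nnz: int, bytes_per_index: int = 8, bytes_per_value: int = 4) -> int:
    # Coordinate (COO-style) storage: topic index + word index + value per nonzero.
    return nnz * (2 * bytes_per_index + bytes_per_value)


K, V = 1_000, 1_000_000          # hypothetical sizes
nnz = K * 5_000                  # hypothetical: ~5,000 nonzero words per topic

print(dense_bytes(K, V) / 1e9)   # 4.0 (GB)
print(sparse_bytes(nnz) / 1e9)   # 0.1 (GB)
```

The dense cost grows with every word in the vocabulary for every topic, while the sparse cost only grows with the (topic, word) pairs that actually carry mass.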

In the appendix of the paper, they develop a method to sparsely sample and update the vocabulary (see page 8).
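I have not reproduced the appendix's sampler here. The sketch below only illustrates the general idea of storing topic-word counts sparsely and updating just the words seen in a minibatch, under the assumption of a symmetric Dirichlet prior with concentration `beta` on each topic's word distribution. The class `SparseTopicWordCounts` and its methods are hypothetical names.

```python
from collections import defaultdict


class SparseTopicWordCounts:
    """Topic-word counts stored as one {word_id: count} dict per topic.

    Only nonzero counts are stored, so memory scales with the number of
    distinct (topic, word) pairs observed rather than with K * V.
    """

    def __init__(self, num_topics: int, vocab_size: int, beta: float):
        self.K = num_topics
        self.V = vocab_size
        self.beta = beta  # symmetric Dirichlet concentration (assumed)
        self.counts = [defaultdict(int) for _ in range(num_topics)]
        self.totals = [0] * num_topics

    def add(self, topic: int, word_ids) -> None:
        # Touch only the words that appear in this minibatch.
        for w in word_ids:
            self.counts[topic][w] += 1
            self.totals[topic] += 1

    def posterior_mean(self, topic: int, word: int) -> float:
        # Mean of Dirichlet(beta + counts). Unseen words share the beta term,
        # so no dense V-length vector is built. Using .get() avoids inserting
        # zero entries into the defaultdict.
        n = self.counts[topic].get(word, 0)
        return (n + self.beta) / (self.totals[topic] + self.V * self.beta)

    def nnz(self) -> int:
        return sum(len(c) for c in self.counts)


s = SparseTopicWordCounts(num_topics=2, vocab_size=1_000_000, beta=0.01)
s.add(0, [5, 5, 9])
print(s.nnz())  # 2
```

The implicit dense distribution still sums to one over all V words; it just never has to be materialized.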

Would it be possible to implement something like this in Pyro so that it works with standard PyTorch optimizers, or with sparse optimizers? I think it should be possible if I write my own distribution along with a custom backward (autograd) rule, but I am not sure.
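For the sparse-optimizer route, this is a minimal sketch in plain PyTorch of what I mean, assuming the topic-word weights live in a `torch.nn.Embedding` with `sparse=True` and are updated with `torch.optim.SparseAdam`. The sizes, `minibatch_step`, and the loss are hypothetical placeholders (not an LDA objective), and I haven't checked how this plugs into Pyro's `pyro.optim` wrappers. Note that the parameters here are still a dense V x K tensor; only the gradients and updates are sparse.

```python
import torch

V, K = 10_000, 50  # hypothetical vocabulary size and number of topics

# One row of K topic weights per vocabulary word. sparse=True makes the
# weight gradient a sparse tensor covering only the rows that were looked up.
word_topic = torch.nn.Embedding(V, K, sparse=True)
opt = torch.optim.SparseAdam(list(word_topic.parameters()), lr=0.01)


def minibatch_step(word_ids: torch.Tensor, doc_topic: torch.Tensor) -> torch.Tensor:
    """One update on a placeholder loss.

    word_ids: (N,) int64 word indices in the minibatch.
    doc_topic: (N, K) per-token topic weights.
    Returns the (sparse) gradient of the embedding weight.
    """
    opt.zero_grad()
    rows = word_topic(word_ids)        # (N, K): gathers only the needed rows
    loss = -(rows * doc_topic).sum()   # placeholder objective, not the LDA ELBO
    loss.backward()
    opt.step()                         # SparseAdam updates only the touched rows
    return word_topic.weight.grad
```

A custom `torch.autograd.Function` would be the other option if the gradient itself needs a hand-written sparse form, but the embedding route avoids writing a backward pass.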

Any guidance here would be hugely appreciated!