Get document topic assignments from the ProdLDA example

Hi! I’m trying to implement the ProdLDA example from this tutorial: "Example: ProdLDA with Flax and Haiku" in the NumPyro documentation (thank you for providing this, btw!)

Everything works perfectly, I’m just a bit stuck on figuring out from this code how you’d extract the document topics (i.e. given a document, what are the probabilities of each topic for that document / I’ve also seen this referred to as ‘per-document topic proportions’). I know these values have to be generated as part of running the model (right?!), I’m just unsure how you’d extract them out of this code.

I think what I need is theta, which is defined in both the model and guide functions. I just don’t understand how to pull it out of the svi.run result so I can map it back onto my corpus of documents.

Thanks so much in advance!

Hi @katester, it is beta (after taking softmax) in the code.

Thanks so much! I’m still confused on how that maps back on to the documents: for example, I have a corpus of 110,280 documents, with 6,826 terms and a dictionary size of 6,775 looking at 15 topics. Beta after taking softmax has the shape (15, 6826), which I thought I should interpret as, for each topic, you have a vector of all of the terms with their corresponding probabilities. This makes sense to me too thinking about the graphing part in that you’re looking at per topic word probabilities.
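To make the shape reasoning above concrete, here is a minimal sketch in plain NumPy. The `beta` array is a random stand-in for the learned parameter, not the tutorial's actual values, and `softmax` and `top_terms` are illustrative names only. It shows that a softmax over the last axis of a (15, 6826) array yields one probability vector over terms per topic, with no document dimension anywhere.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


num_topics, vocab_size = 15, 6826  # sizes quoted in the post above
rng = np.random.default_rng(0)
beta = rng.normal(size=(num_topics, vocab_size))  # random stand-in for the learned beta

# Each row is a distribution over terms for one topic.
topic_word_probs = softmax(beta, axis=-1)

# Indices of the 10 highest-probability terms for each topic.
top_terms = np.argsort(-topic_word_probs, axis=-1)[:, :10]
```

Since neither axis has length 110,280, this array cannot carry per-document information by itself.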

However, what I’d expect for per document topic proportions would be something that in some dimension has 110,280 as a count, thinking how else could it be tied back to the individual documents themselves? Am I thinking about this correctly or missing something very stupid haha

Hopefully this makes sense and thanks again for the help I truly super appreciate it!

Oops, sorry @katester! You are right, it should be theta. I think one way to do this is to return it from the guide:

def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    ...
    return concentration, theta

then substitute the optimized parameters:

with numpyro.handlers.seed(rng_seed=0), \
        numpyro.handlers.substitute(data=svi_result.params):
    concentration, theta = guide(docs, hyperparams)

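concentration parameterizes the guide's distribution over theta (a Dirichlet in the tutorial), so its row-normalized value, concentration / concentration.sum(-1, keepdims=True), is the mean of theta. You might want to use that instead of the sampled theta. This assumes you provide a batch of docs. For a single doc, add a leading singleton dimension so that it has shape (1, number of terms).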
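As a rough illustration of how the returned values map back onto documents, the sketch below uses plain NumPy with a random stand-in for `concentration`. It assumes the guide's distribution over theta is a Dirichlet, as in the tutorial; the names `doc_topic_probs`, `doc_topic_assignment`, and `single_doc_batch` are hypothetical. The mean of a Dirichlet is its concentration divided by the row sum, which gives one row of topic proportions per document.

```python
import numpy as np

num_docs, num_topics = 5, 15  # toy sizes; the corpus in this thread has 110,280 documents
rng = np.random.default_rng(0)

# Random stand-in for the `concentration` array returned by the modified guide.
# Dirichlet concentrations must be strictly positive.
concentration = rng.gamma(shape=2.0, scale=1.0, size=(num_docs, num_topics))

# Mean of a Dirichlet: each row of concentration divided by its row sum.
doc_topic_probs = concentration / concentration.sum(axis=-1, keepdims=True)

# Most probable topic for each document (row i corresponds to docs[i]).
doc_topic_assignment = doc_topic_probs.argmax(axis=-1)

# A single document is a 1-D count vector; add a leading batch dimension
# before passing it to the guide.
vocab_size = 20  # toy vocabulary size
single_doc = np.zeros(vocab_size)
single_doc_batch = single_doc[None, :]
```

Row order is preserved only if the guide sees the documents in their original order, which is worth checking if the guide subsamples.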

Ok last follow-up question (hopefully!) – I modified the guide function and inserted the substitution of the optimized parameters code in the main function (following the generation of svi_result via the run_inference function). Is this the right place to put it? I am getting an Assertion error which makes me think this is not the right place… specifically:

~/opt/anaconda3/envs/franken_tm/lib/python3.9/site-packages/numpyro/contrib/module.py in flax_module(name, nn_module, input_shape, apply_rng, mutable, **kwargs)
 nn_state = numpyro_mutable(name + "$state")
 assert nn_state is None or isinstance(nn_state, dict)
 assert (nn_state is None) == (nn_params is None)
 if nn_params is None:
 AssertionError:

Thank you so much again! I promise I’m doing my best to debug on my own, still wrapping my head around everything NumPyro :smiley:

Oops, the network has mutable state so we need to deal with it. I guess you can do

numpyro.handlers.substitute(data={**svi_result.params, **svi_result.state.mutable_state})
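To show what the merged dictionary looks like, here is a small sketch using plain Python dicts. The contents are hypothetical stand-ins for `svi_result.params` and `svi_result.state.mutable_state`. The `"$state"` suffix comes from the traceback above, and the `"$params"` suffix is assumed to be the matching key for the network weights. The failing assertion checks that the parameters and the state are either both present or both missing, so substituting only the parameters trips it.

```python
# Hypothetical stand-in for svi_result.params (optimized network weights).
params = {
    "encoder$params": {"dense": [0.1, 0.2]},
    "decoder$params": {"dense": [0.3, 0.4]},
}

# Hypothetical stand-in for svi_result.state.mutable_state (e.g. batch-norm statistics).
mutable_state = {
    "encoder$state": {"batch_stats": [0.0, 1.0]},
    "decoder$state": {"batch_stats": [0.0, 1.0]},
}

# Merge both so that every module finds its params AND its state when substituted.
substitute_data = {**params, **mutable_state}

# For each module, params and state are now either both present or both absent.
for name in ("encoder", "decoder"):
    has_params = (name + "$params") in substitute_data
    has_state = (name + "$state") in substitute_data
    assert has_params == has_state
```

The two dictionaries use different key suffixes, so the merge does not overwrite anything.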

If it works, it would be nice to enhance the example to get document topic assignments.

Woo hoo! It worked! I just had to make one small adjustment to the hyperparameters argument before running the guide function again: changing batch_size to docs.shape[0]. Thanks again for all of your help! I can modify the example on GitHub and put in a pull request if that would be helpful, just let me know :slight_smile:
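The adjustment described above can be sketched as follows. This assumes `hyperparams` is a plain dict with a `batch_size` key, as in the tutorial; the values and the `predict_hyperparams` name are hypothetical. Presumably the change is needed because the guide subsamples documents in batches of `batch_size`, so setting it to the corpus size makes the guide return one row per document.

```python
import numpy as np

# Hypothetical training-time settings.
hyperparams = {"batch_size": 32, "num_topics": 15}

# Toy stand-in for the document-term matrix (num_docs x number of terms).
docs = np.zeros((100, 50))

# Copy the settings and override batch_size so the whole corpus is
# processed in one pass; the original dict is left unchanged.
predict_hyperparams = {**hyperparams, "batch_size": docs.shape[0]}
```

Building a new dict rather than mutating `hyperparams` keeps the training configuration intact in case inference is run again later.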

Yeah, it would be super helpful because we don’t have any example for networks with mutable state. Btw, for prediction, you will need to set is_training=False.