LDA-like model with multinomial distribution

Hi! I have an LDA-like model in which I want to sample both z and w with multinomial distributions. In LDA, each word is assumed to belong to a topic z, so in my model I sample z like this:

z_b_t = numpyro.sample(f"z_b_{t}",

The shape of z_b_t is (batch_size, topic_num), and each element along the topic_num dimension is the number of occurrences of that topic. I would like to implement the following procedure using a multinomial distribution:

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.indexing import Vindex

# Choose the topic-word distribution first; it is represented as `phi`, with shape (topic_num, vocabulary_size)
p_z = Vindex(phi)[z_b_t]
# Sample the word
word = numpyro.sample("word", dist.Categorical(p_z), obs=obs_word)

Sampling a word requires choosing its topic first, which seems hard to implement when z is sampled from a multinomial distribution in NumPyro. I would really appreciate any help!