NumPyro implementation for HashingMarginal and Search inference in Pyro Rational Speech Act (RSA) examples

Dear colleagues in the community,

Short question: Is there a way to implement the HashingMarginal and Search inference used in the Pyro RSA examples with NumPyro's functionality?

Long background: I'm trying to do Bayesian regression on an incremental Rational Speech Act (RSA) model from [1]. My starting point is this tutorial for RSA: The Rational Speech Act framework — Pyro Tutorials 1.8.5 documentation. I first implemented the models in Pyro in the fashion of that tutorial and performed MCMC inference with NUTS. The initial results were extremely slow and unstable across chains: the fastest chain took 4 hours for only 4 samples, and sometimes the warm-up phase took 7 hours. I then read that NumPyro is much more efficient for MCMC inference, and some of its extra utilities for dealing with discrete variables seem very useful. So I set out to port my models from Pyro to NumPyro, but I am stuck finding substitutes for these two key components from the RSA examples, which I rely on heavily. My models look like the following:

def literal_listener(words, states, color_semvalue=0.98, form_semvalue=0.98, wf=0.6, k=0.5):
    if len(words.split()) <= 1:
        # base case: a single word is interpreted against the state prior
        current_state_prior = state_prior(states)
        current_word = words
    else:
        # recursive case: interpret the rest of the utterance first
        current_state_prior = literal_listener(words.split()[1:][0], states)
        current_word = words.split()[0]
    obj = pyro.sample("obj", current_state_prior)
    utt_truth_val = adjMeaning(current_word, obj, current_state_prior, color_semvalue, form_semvalue, wf, k)
    pyro.factor("literal_meaning", 0. if utt_truth_val else -9999999.)
    return obj

def global_speaker_production(states, alpha, color_semvalue, form_semvalue, wf, k, cost_weight):
    obj = states[0]  # We assume that the target object is always the first one in the list
    with poutine.scale(scale=torch.tensor(alpha)):
        utterance = pyro.sample("utterance", utterance_prior())
        # wf and k passed in the order of literal_listener's signature
        pyro.factor("listener", alpha * (literal_listener(utterance, states, color_semvalue, form_semvalue, wf, k).log_prob(obj)
                                         - cost(utterance, cost_weight)))
    return utterance
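To make the question concrete: as I understand the tutorial, `Search` exhaustively enumerates the executions of a discrete model and `HashingMarginal` collapses the weighted return values into a normalized distribution (which is what lets the speaker call `.log_prob` above, once `literal_listener` is wrapped in the tutorial's `Marginal` helper). A stdlib-only sketch of that idea, with made-up names that are not Pyro's API:

```python
import math
from collections import defaultdict

def enumerate_marginal(weighted_executions):
    """Collapse (return_value, log_weight) pairs from exhaustively
    enumerated executions into a normalized marginal keyed by the
    (hashable) return value -- the core idea behind Search + HashingMarginal."""
    log_w = defaultdict(lambda: -math.inf)
    for value, lw in weighted_executions:
        a, b = log_w[value], lw              # log-sum-exp accumulation
        m = max(a, b)
        log_w[value] = m + math.log(math.exp(a - m) + math.exp(b - m))
    total = max(log_w.values())
    total += math.log(sum(math.exp(lw - total) for lw in log_w.values()))
    return {v: math.exp(lw - total) for v, lw in log_w.items()}

# toy literal listener: uniform prior over two objects, and the factor
# -9999999. (as in the model above) vetoes the object the word is false of
executions = [("blue_square", math.log(0.5) + 0.0),
              ("red_circle",  math.log(0.5) - 9999999.0)]
marginal = enumerate_marginal(executions)
# marginal["blue_square"] is essentially 1.0
```

This is what I would need to reproduce on the NumPyro side.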

[1] Schlotterbeck, F., & Wang, H. (2023). An Incremental RSA Model for Adjective Ordering Preferences in Referential Visual Context. Proceedings of the Society for Computation in Linguistics, 6(1), 121–132.

doing something like this in numpyro would probably require a fair amount of work. probably easier to try to get things to work in pyro.

how many latent variables are in your model (total dimension, continuous vs discrete, etc)? hmc may not be the right algorithm if e.g. your posterior surface is very rough

Thanks for your reply!
This estimate is extremely helpful for my planning.

doing something like this in numpyro would probably require a fair amount of work

I have a few more follow-up questions:

  1. I am curious whether Pyro's MCMC inference is compatible with nested reasoning, recursion, and caching. Given that back-and-forth Bayesian reasoning plays an important role in RSA, and that our incremental RSA models employ recursion at various levels, it's crucial for us to understand how these interact with Pyro's MCMC inference. Any insights on this would be greatly appreciated.
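To make the caching part concrete: the tutorial's `memoize` helper wraps each marginal so that a nested sub-inference is computed once per argument tuple and re-used afterwards. In stdlib terms it behaves like `functools.lru_cache`; a toy sketch with hypothetical names and made-up numbers:

```python
from functools import lru_cache

calls = {"count": 0}  # only here to count how often the nested inference runs

@lru_cache(maxsize=None)
def literal_listener_marginal(utterance):
    # stand-in for an expensive nested inference over a discrete model
    calls["count"] += 1
    return (("obj_a", 0.9), ("obj_b", 0.1))  # frozen toy marginal

def speaker_score(utterance, target):
    # every speaker evaluation re-uses the cached listener marginal
    return dict(literal_listener_marginal(utterance))[target]

for _ in range(100):
    speaker_score("blue", "obj_a")
# the nested "inference" ran exactly once despite 100 speaker calls
```

My question is essentially whether this kind of argument-keyed caching stays valid when the outer model is being sampled with MCMC.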

if e.g. your posterior surface is very rough

  2. I’m not familiar with the geometry of posteriors in general. But in my models, two parameters can specifically drive the predicted values in opposing directions. Any further information on this topic would be appreciated.

as a starting point i’d try limiting max_tree_depth to something smaller like 4 or 6 and seeing if you can still get reasonable results using NUTS