# NumPyro implementation for HashingMarginal and Search inference in Pyro Rational Speech Act (RSA) examples

Dear colleagues in the community,

Short question: Is there a way to implement the HashingMarginal and Search inference used in the Pyro RSA examples with NumPyro's functionality?

Long background: I'm trying to do Bayesian regression on an incremental Rational Speech Act (RSA) model (Schlotterbeck & Wang, 2023; full reference below). My starting point is this tutorial for RSA: The Rational Speech Act framework — Pyro Tutorials 1.8.5 documentation. I first implemented the models in Pyro in the fashion of that tutorial and performed MCMC inference with NUTS. The initial results were extremely slow and unstable across chains: the fastest chain took 4 hours for only 4 samples, and the warm-up phase sometimes took 7 hours. I then read that NumPyro is much more efficient for MCMC inference, and its extra utilities for dealing with discrete variables seem very useful. So I started porting the models from Pyro to NumPyro, but I am stuck finding substitutes for two key components from the RSA examples that I rely on heavily. My models look like the following:

```python
import torch
import pyro
import pyro.poutine as poutine

@Marginal
def literal_listener(words, states, color_semvalue=0.98, form_semvalue=0.98, wf=0.6, k=0.5):
    if len(words.split()) <= 1:
        current_state_prior = state_prior(states)
        current_word = words
    else:
        # Recurse on the remaining words, joined back into a string so the
        # recursive call can split them again
        current_state_prior = literal_listener(" ".join(words.split()[1:]), states)
        current_word = words.split()
    obj = pyro.sample("obj", current_state_prior)
    utt_truth_val = adjMeaning(current_word, obj, current_state_prior,
                               color_semvalue, form_semvalue, wf, k)
    pyro.factor("literal_meaning", 0. if utt_truth_val else -9999999.)
    return obj

@Marginal
def global_speaker_production(states, alpha, color_semvalue, form_semvalue, wf, k, cost_weight):
    obj = states[0]  # we assume the target object is always the first one in the list
    with poutine.scale(scale=torch.tensor(alpha)):
        utterance = pyro.sample("utterance", utterance_prior())
    # arguments passed in the signature's order (wf before k)
    pyro.factor("listener", alpha * (literal_listener(utterance, states, color_semvalue,
                                                      form_semvalue, wf, k).log_prob(obj)
                                     - cost(utterance, cost_weight)))
    return utterance
```
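For context on what a replacement would need to do: Pyro's `HashingMarginal` over `Search` enumerates every execution of the model, collects the accumulated log-weights (priors plus `pyro.factor` terms), and normalizes them into a distribution over return values. Below is a minimal pure-Python sketch of that idea, not NumPyro API (the helper names are hypothetical); porting `@Marginal` to NumPyro/JAX would amount to doing this enumeration explicitly, e.g. normalizing with `jax.scipy.special.logsumexp` over the discrete support.

```python
import math

def enumerate_marginal(support, log_weight_fn):
    """Emulate Marginal by explicit enumeration: score every value in a
    discrete support, then normalize the log-weights into a distribution."""
    log_ws = [log_weight_fn(x) for x in support]
    m = max(log_ws)
    total = m + math.log(sum(math.exp(lw - m) for lw in log_ws))  # logsumexp
    return {x: math.exp(lw - total) for x, lw in zip(support, log_ws)}

# Toy literal listener: two objects, the utterance is true only of object 0.
def log_weight(obj):
    prior = math.log(0.5)                    # uniform state prior
    truth = 0.0 if obj == 0 else -9999999.0  # hard semantic factor
    return prior + truth

posterior = enumerate_marginal([0, 1], log_weight)
```

The soft/hard `factor` statements in the RSA models map directly onto the `log_weight_fn` here; the nesting (listener inside speaker) corresponds to calling `enumerate_marginal` inside another model's weight function.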

Schlotterbeck, F., & Wang, H. (2023). An Incremental RSA Model for Adjective Ordering Preferences in Referential Visual Context. *Proceedings of the Society for Computation in Linguistics*, 6(1), 121–132.

doing something like this in numpyro would probably require a fair amount of work. it's probably easier to try to get things working in pyro.

how many latent variables are in your model (total dimension, continuous vs. discrete, etc.)? hmc may not be the right algorithm if e.g. your posterior surface is very rough

as a starting point i'd try limiting `max_tree_depth` to something smaller like `4` or `6` and seeing if you can still get reasonable results using NUTS