Dear colleagues in the community,
Short question: Is there a way to implement the HashingMarginal and Search inference used in the Pyro RSA examples with NumPyro’s functionality?
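For context, the Pyro RSA examples build `@Marginal` from helpers in their `search_inference.py` module, roughly as `memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))`. `Search` enumerates every execution of a model whose random choices are all finite and discrete, and `HashingMarginal` merges executions that return the same (hashable) value into one categorical distribution. The plain-Python sketch below only illustrates that idea. `enumerate_marginal`, `toy_literal_listener` and the toy objects are hypothetical names, not Pyro or NumPyro APIs. Unlike Pyro's trace-based `Search`, this sketch assumes the supports of the choices are passed in explicitly.

```python
import itertools
import math
from collections import defaultdict


def enumerate_marginal(fn, supports):
    """Exact marginal over the return values of `fn` (hypothetical helper).

    `fn(*choices)` must return (value, log_weight), where log_weight is the
    log prior of the choices plus any factor terms. Every combination of
    choices is visited (the role of Search), and combinations that return
    equal hashable values are merged (the role of HashingMarginal).
    """
    weights = defaultdict(float)
    for choices in itertools.product(*supports):
        value, log_weight = fn(*choices)
        weights[value] += math.exp(log_weight)  # exp(-inf) == 0.0
    total = sum(weights.values())
    return {value: w / total for value, w in weights.items()}


# Toy literal listener: uniform prior over three objects, hard truth values.
OBJECTS = ("blue square", "blue circle", "green square")


def toy_literal_listener(word):
    def fn(obj):
        log_prior = math.log(1.0 / len(OBJECTS))
        log_factor = 0.0 if word in obj.split() else -math.inf
        return obj, log_prior + log_factor

    return enumerate_marginal(fn, [OBJECTS])


print(toy_literal_listener("blue"))
# → {'blue square': 0.5, 'blue circle': 0.5, 'green square': 0.0}
```

Any model of this shape can also be written with arrays over the finite support, which is what makes it usable inside a JAX/NumPyro model.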
Long background: I’m trying to do Bayesian regression on the incremental Rational Speech Act (RSA) model from [1]. My starting point is the Pyro tutorial “The Rational Speech Act framework” (Pyro Tutorials 1.8.5 documentation). I first implemented my models in Pyro in the style of that tutorial and ran MCMC inference with NUTS. The results were extremely slow and unstable across chains: the fastest chain took 4 hours for only 4 samples, and warm-up sometimes took 7 hours. I then read that NumPyro is much more efficient for MCMC inference, and its extra utilities for handling discrete variables look very useful. So I decided to port my models from Pyro to NumPyro, but I’m stuck finding substitutes for the two key components of the RSA examples that I rely on heavily, HashingMarginal and Search. My models look like the following:
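```python
import torch
import pyro
import pyro.poutine as poutine
from search_inference import HashingMarginal, memoize, Search  # helper module from the Pyro RSA examples

def Marginal(fn):  # as defined in the Pyro RSA tutorial
    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))

# state_prior, adjMeaning, utterance_prior and cost are my own helpers (not shown).
@Marginal
def literal_listener(words, states, color_semvalue=0.98, form_semvalue=0.98, wf=0.6, k=0.5):
    if len(words.split()) <= 1:
        current_state_prior = state_prior(states)
        current_word = words
    else:
        # Interpret the second word first; its posterior is the prior for the first word.
        # Parameters are passed on positionally (Marginal's wrapper only accepts *args).
        current_state_prior = literal_listener(words.split()[1], states, color_semvalue, form_semvalue, wf, k)
        current_word = words.split()[0]
    obj = pyro.sample("obj", current_state_prior)
    utt_truth_val = adjMeaning(current_word, obj, current_state_prior, color_semvalue, form_semvalue, wf, k)
    # (Near-)hard constraint: keep only executions where the word is true of obj
    pyro.factor("literal_meaning", 0. if utt_truth_val else -9999999.)
    return obj
```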
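Since the literal listener only ever chooses among a finite list of objects, one possible NumPyro-friendly substitute for `HashingMarginal(Search(literal_listener))` is to compute the whole posterior as a normalized log-probability vector with `jax.numpy`. If `adjMeaning` samples a Bernoulli truth value, enumerating it and keeping only the "true" branch amounts to multiplying the prior by that probability. This minimal sketch relies on that assumption. `meaning_matrix` is a made-up soft semantics standing in for `adjMeaning`: a matching feature gets `*_semvalue`, any other feature gets `1 - *_semvalue`, and `wf` and `k` are ignored. The toy objects and words are also assumptions. NumPyro's funsor-based enumeration (`config_enumerate`) targets discrete sites inside a single model, so it does not directly give a reusable per-call marginal like `@Marginal`. The array form avoids needing one.

```python
import math

import jax
import jax.numpy as jnp

# Hypothetical toy setup: each object is (colour, form); the target is index 0.
STATES = (("blue", "square"), ("blue", "circle"), ("green", "square"))
WORDS = ("blue", "green", "square", "circle")
COLOR_WORDS = ("blue", "green")

IS_COLOR = jnp.array([w in COLOR_WORDS for w in WORDS])             # (num_words,)
MATCHES = jnp.array([[w in obj for obj in STATES] for w in WORDS])  # (num_words, num_states)


def meaning_matrix(color_semvalue, form_semvalue):
    """P(word is true of object): a made-up soft stand-in for adjMeaning."""
    semvalue = jnp.where(IS_COLOR[:, None], color_semvalue, form_semvalue)
    return jnp.where(MATCHES, semvalue, 1.0 - semvalue)


def literal_listener_logprobs(utterance, color_semvalue=0.98, form_semvalue=0.98):
    """log L0(state | utterance) as a vector over STATES.

    Words are interpreted right to left (matching the recursion above for one-
    and two-word utterances). Each step multiplies the current posterior by the
    word's truth probability and renormalises.
    """
    meaning = meaning_matrix(color_semvalue, form_semvalue)
    log_post = jnp.full(len(STATES), -math.log(len(STATES)))  # uniform state prior
    for word in reversed(utterance.split()):
        log_post = jax.nn.log_softmax(log_post + jnp.log(meaning[WORDS.index(word)]))
    return log_post


# The target object (index 0) receives most of the probability mass.
print(jnp.exp(literal_listener_logprobs("blue square")))
```

Because the result is differentiable in the semantic parameters, something like `jax.grad(lambda c: literal_listener_logprobs("blue square", c)[0])(0.9)` works. NUTS needs exactly this kind of gradient if those parameters are inferred.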
```python
@Marginal
def global_speaker_production(states, alpha, color_semvalue, form_semvalue, wf, k, cost_weight):
    obj = states[0]  # We assume that the target object is always the first one in the list
    with poutine.scale(scale=torch.tensor(alpha)):
        utterance = pyro.sample("utterance", utterance_prior())
    # wf before k, matching literal_listener's signature (arguments must be positional).
    # The factor stays outside poutine.scale so that alpha is applied only once.
    listener = literal_listener(utterance, states, color_semvalue, form_semvalue, wf, k)
    pyro.factor("listener", alpha * (listener.log_prob(obj) - cost(utterance, cost_weight)))
    return utterance
```
[1] Schlotterbeck, F., & Wang, H. (2023). An Incremental RSA Model for Adjective Ordering Preferences in Referential Visual Context. Proceedings of the Society for Computation in Linguistics, 6(1), 121-132.