Is there a technique for finding the argmax of expected utility with inference in Pyro?

The agentmodels web book discusses a technique that uses factor statements in WebPPL to infer a discrete action that (softly) maximizes an expected utility function:

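var softMaxAgent = function(state) {
  return Infer({
    model() {
      // Prior: choose an action uniformly at random
      var action = uniformDraw(actions);

      // Inner inference: expected utility of the state reached by taking `action`
      var expectedUtility = function(action) {
        return expectation(Infer({
          model() {
            return utility(transition(state, action));
          }
        }));
      };

      // Reweight each action by exp(alpha * EU(action)), giving a softmax policy
      factor(alpha * expectedUtility(action));
      return action;
    }
  });
};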
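For reference, here is the distribution this program defines, written out under the assumption that T(s, a) is the transition distribution and U the utility. The uniform prior over actions contributes a constant that cancels in normalization, so the factor alone produces a softmax at temperature 1/alpha, and as alpha grows the policy concentrates on the argmax of expected utility:

$$
P(a \mid s) = \frac{\exp\big(\alpha\,\mathrm{EU}(s,a)\big)}{\sum_{a'} \exp\big(\alpha\,\mathrm{EU}(s,a')\big)},
\qquad
\mathrm{EU}(s,a) = \mathbb{E}_{s' \sim T(s,a)}\big[U(s')\big]
$$

$$
\lim_{\alpha \to \infty} P(a \mid s) = \frac{\mathbf{1}\big[a \in A^*\big]}{|A^*|},
\qquad
A^* = \arg\max_{a} \mathrm{EU}(s,a)
$$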

I have the impression that Pyro users are not meant to manipulate the log-mass directly in this way. Is there an analogous pattern or abstraction in Pyro, aside from something like Thompson sampling?


~ Robert


Hi @osazuwa,

To make this work, you’ll need two components: (1) a factor statement and (2) nested expectations.

You can simulate a factor(utility) statement in Pyro using

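# observing the Delta's own value makes this site add exactly `utility` to the log joint
pyro.sample("foo", dist.Delta(torch.tensor(0.), log_density=utility), obs=torch.tensor(0.))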

Also I’ll try adding pyro.factor as a primitive. I agree this should be easier.
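Whatever inference algorithm is used, a factor statement adds its argument to the model's log joint, so when samples are drawn from the prior their importance weights are multiplied by exp(factor value). The sketch below shows that reweighting in plain Python (no Pyro), assuming a hypothetical three-action problem with known, deterministic utilities; it is an illustration of the semantics, not Pyro's implementation.

```python
import math
import random
from collections import Counter

# Hypothetical toy problem: each action has a known (deterministic) utility.
UTILITY = {"a": 0.0, "b": 1.0, "c": 2.0}
ALPHA = 1.0  # inverse temperature


def weighted_samples(num_samples, rng):
    """Importance sampling with the uniform prior as the proposal.

    The factor step adds alpha * utility to each sample's log weight,
    which is what factor(alpha * utility) contributes to the log joint.
    """
    samples = []
    actions = list(UTILITY)
    for _ in range(num_samples):
        action = rng.choice(actions)            # prior: uniformDraw(actions)
        log_weight = ALPHA * UTILITY[action]    # factor(alpha * utility)
        samples.append((action, log_weight))
    return samples


def normalize(samples):
    """Self-normalize the importance weights into a distribution over actions."""
    max_lw = max(lw for _, lw in samples)  # subtract the max for numerical stability
    totals = Counter()
    for action, lw in samples:
        totals[action] += math.exp(lw - max_lw)
    z = sum(totals.values())
    return {a: w / z for a, w in totals.items()}


def exact_softmax(utilities, alpha):
    """Closed-form softmax over actions, for comparison."""
    z = sum(math.exp(alpha * u) for u in utilities.values())
    return {a: math.exp(alpha * u) / z for a, u in utilities.items()}


if __name__ == "__main__":
    rng = random.Random(0)
    approx = normalize(weighted_samples(20000, rng))
    exact = exact_softmax(UTILITY, ALPHA)
    for a in UTILITY:
        print(a, round(approx[a], 3), round(exact[a], 3))
```

The weighted estimate should track the exact softmax closely; replacing the uniform choice with a smarter proposal is where a real inference algorithm would differ.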

To compute the softmax over nested expectations at temperature 1/alpha, you can generate posterior samples from a Pyro model that includes a factor of alpha * utility. Alternatively, in some cases you can use infer_discrete to compute optimal actions in discrete, structured action spaces. I haven’t thought about how to compute the inner expectation; Pyro inference algorithms don’t nest as nicely as WebPPL inference algorithms do.
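One possible workaround for the nesting problem (an assumption, not a Pyro API) is to estimate the inner expectation by plain Monte Carlo outside any inference call, then compute the outer softmax exactly by enumerating a small discrete action set. The sketch below does this in plain Python; the actions, transition, and utility functions are all hypothetical.

```python
import math
import random

ACTIONS = [-1, 0, 1]  # hypothetical discrete actions: move left, stay, move right


def transition(state, action, rng):
    """Hypothetical noisy dynamics: the intended move succeeds with probability 0.8."""
    return state + action if rng.random() < 0.8 else state


def utility(state):
    """Hypothetical utility: being close to position 3 is best."""
    return -abs(state - 3)


def expected_utility(state, action, rng, num_samples=2000):
    """Inner expectation, estimated by Monte Carlo over transitions."""
    total = 0.0
    for _ in range(num_samples):
        total += utility(transition(state, action, rng))
    return total / num_samples


def softmax_policy(state, alpha, rng):
    """Outer step: P(action) proportional to exp(alpha * EU(action)), by enumeration."""
    eu = {a: expected_utility(state, a, rng) for a in ACTIONS}
    m = max(alpha * u for u in eu.values())  # subtract the max for numerical stability
    weights = {a: math.exp(alpha * u - m) for a, u in eu.items()}
    z = sum(weights.values())
    return {a: w / z for a, w in weights.items()}


if __name__ == "__main__":
    rng = random.Random(0)
    print(softmax_policy(0, alpha=1.0, rng=rng))
    # With a large alpha the policy puts almost all mass on the argmax action (+1 here).
    print(softmax_policy(0, alpha=50.0, rng=rng))
```

Enumerating the outer action is only feasible for small action sets; for larger or structured action spaces the outer step would need sampling or something like infer_discrete instead.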

