Is there a technique for finding the argmax of expected utility with inference in Pyro?

The agentmodels web book discusses a technique that uses factor statements in WebPPL to infer a discrete action that maximizes an expected utility function:

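```js
// Assumes `actions`, `alpha`, `transition` and `utility` are defined elsewhere (as in agentmodels).
var softMaxAgent = function(state) {
  return Infer({
    model() {

      var action = uniformDraw(actions);

      // Inner inference: expected utility of taking `action` in `state`.
      var expectedUtility = function(action) {
        return expectation(Infer({
          model() {
            return utility(transition(state, action));
          }
        }));
      };

      // Reweight the uniform prior by exp(alpha * EU(action)).
      factor(alpha * expectedUtility(action));

      return action;
    }
  });
};
```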
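To make the target of the WebPPL program concrete, here is a plain-Python sketch (no Pyro or WebPPL involved) of the distribution it describes: a uniform prior over actions reweighted by exp(alpha * expected utility), with the inner expectation computed by exact enumeration. The actions, outcome probabilities and utilities are hypothetical toy values, and `argmax_action` shows the large-alpha limit the question is about.

```python
import math

# Hypothetical toy problem (not from the original post): each action leads to
# a random outcome with the probabilities below.
actions = ["italian", "french"]
transition_probs = {
    "italian": {"pizza": 0.9, "pasta": 0.1},
    "french": {"steak": 0.5, "crepe": 0.5},
}
utility_table = {"pizza": 10.0, "pasta": 6.0, "steak": 8.0, "crepe": 4.0}


def expected_utility(action):
    # Inner expectation: enumerate the outcome distribution exactly.
    return sum(p * utility_table[o] for o, p in transition_probs[action].items())


def softmax_agent(alpha):
    # Outer "Infer + factor": the uniform prior cancels after normalization,
    # so P(action) is proportional to exp(alpha * EU(action)).
    log_weights = {a: alpha * expected_utility(a) for a in actions}
    m = max(log_weights.values())  # subtract the max for numerical stability
    unnorm = {a: math.exp(w - m) for a, w in log_weights.items()}
    z = sum(unnorm.values())
    return {a: w / z for a, w in unnorm.items()}


def argmax_action():
    # The limit alpha -> infinity: the action with the highest expected utility.
    return max(actions, key=expected_utility)


if __name__ == "__main__":
    print(softmax_agent(alpha=1.0), argmax_action())
```

Increasing `alpha` moves probability mass toward `argmax_action()`; `alpha = 0` gives back the uniform prior.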

I have the impression that Pyro users are not meant to manipulate log-mass directly in this way. Is there an analogous pattern or abstraction in Pyro, aside from something like Thompson sampling?

Cheers

~ Robert


Hi @osazuwa,

To make this work, you’ll need two components: (1) a factor statement and (2) nested expectations.

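You can simulate a factor(utility) statement in Pyro using:

```python
# `utility` is a scalar tensor; observing the Delta at its own value adds log_density to the log joint.
pyro.sample("foo", dist.Delta(torch.tensor(0.), log_density=utility),
            obs=torch.tensor(0.))
```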
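A short derivation of why this reproduces the WebPPL agent, assuming a uniform prior over a finite action set A and writing EU(a) for the inner expectation of utility(transition(state, a)): a Delta observed at its own support point contributes exactly its log_density u to the log joint, so a factor of α·EU(a) turns the uniform prior into a softmax over expected utilities.

```latex
\begin{align}
\log p_{\mathrm{Delta}}\bigl(x = v \mid v,\ \mathrm{log\_density} = u\bigr) &= u \\
p(a \mid \text{factor}) = \frac{\tfrac{1}{|A|}\, e^{\alpha\,\mathrm{EU}(a)}}{\sum_{a' \in A} \tfrac{1}{|A|}\, e^{\alpha\,\mathrm{EU}(a')}}
  &= \frac{e^{\alpha\,\mathrm{EU}(a)}}{\sum_{a' \in A} e^{\alpha\,\mathrm{EU}(a')}}
\end{align}
```

As α grows, this distribution concentrates on argmax over a of EU(a) (assuming a unique maximizer), which is the argmax the question asks for.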

Also I’ll try adding pyro.factor as a primitive. I agree this should be easier.

To compute softmax nested expectations at temperature 1/alpha, you can generate posterior samples from a Pyro model that includes a factor of alpha * utility. Alternatively, in some cases you can use infer_discrete to compute optimal actions in discrete structured action spaces. I haven't thought about how to compute the inner expectation; Pyro inference algorithms don't nest as nicely as WebPPL inference algorithms.
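One way to read "generate posterior samples from a model with alpha * utility" is self-normalized importance sampling: draw actions from the prior, weight each by the factor exp(alpha * EU(action)), and resample. The sketch below does this in plain Python with hypothetical expected utilities and compares the result against the exact softmax. In Pyro, the weight would come from the Delta trick above, or from `pyro.factor(name, log_factor)` in recent Pyro releases.

```python
import math
import random
from collections import Counter

# Hypothetical expected utilities for three actions (illustrative numbers only).
EXPECTED_UTILITY = {"left": 1.0, "stay": 0.0, "right": 2.0}


def posterior_action_samples(alpha, num_particles=20000, seed=0):
    rng = random.Random(seed)
    actions = list(EXPECTED_UTILITY)
    # 1. Sample from the prior (uniform over actions).
    particles = [rng.choice(actions) for _ in range(num_particles)]
    # 2. Each particle's log-weight is the factor alpha * EU(action).
    log_w = [alpha * EXPECTED_UTILITY[a] for a in particles]
    m = max(log_w)
    weights = [math.exp(lw - m) for lw in log_w]
    # 3. Resample in proportion to the weights to get approximate posterior samples.
    return rng.choices(particles, weights=weights, k=num_particles)


def exact_softmax(alpha):
    z = sum(math.exp(alpha * u) for u in EXPECTED_UTILITY.values())
    return {a: math.exp(alpha * u) / z for a, u in EXPECTED_UTILITY.items()}


if __name__ == "__main__":
    samples = posterior_action_samples(alpha=1.0)
    print(Counter(samples), exact_softmax(alpha=1.0))
```

Sampling from the prior degrades as alpha grows, because nearly all the weight lands on a few particles; for small discrete action spaces, exact enumeration is usually the better choice.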

Cheers,
Fritz


Hi @osazuwa and @fritzo, I'm sorry for asking a dumb follow-up question, but I still struggle to understand quite basic things about Pyro. My question is: how do I compute the expectation of a distribution defined by a model in Pyro at all? Is there any equivalent to WebPPL's "expectation" function?

I don't know about WebPPL's expectation function, but the typical way to compute an expectation of something in Pyro is to:

  1. create a latent variable model
  2. create a guide (or choose an autoguide) to serve as the variational family for inference
  3. fit the guide to your model + data using SVI
  4. estimate the expectation of any function f(z) by Monte Carlo: draw many samples z from the fitted guide, compute f(z) for each, and average the results.
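Step 4 can be sketched in plain Python, with a hypothetical Normal guide standing in for the output of steps 1 to 3 (in Pyro, the samples would instead come from running the fitted guide, for example through `pyro.infer.Predictive`). For f(z) = z², the exact answer is loc² + scale² = 1.25, which the estimate should approach.

```python
import random

# Stand-in for steps 1-3: pretend SVI produced a Normal guide q(z) = N(loc, scale).
# These numbers are hypothetical; in Pyro they would come from the fitted guide's params.
GUIDE_LOC, GUIDE_SCALE = 1.0, 0.5


def sample_guide(rng):
    return rng.gauss(GUIDE_LOC, GUIDE_SCALE)


def mc_expectation(f, num_samples=100_000, seed=0):
    """Step 4: Monte Carlo estimate of E_q[f(z)] as the average of f over guide draws."""
    rng = random.Random(seed)
    return sum(f(sample_guide(rng)) for _ in range(num_samples)) / num_samples


if __name__ == "__main__":
    print(mc_expectation(lambda z: z * z))
```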

OK, but that would calculate the expectation outside the model. My problem is that I have a model which is a rather complex network in which many variables depend on the expectations of other variables. In WebPPL I can simply use the expectation operator inside the model to express this.

Hi again, I would appreciate any help whatsoever on this.

To clarify my question a little further:

Assume I have a model M1 for the conditional probability distribution of a random variable X given some other variable Y, and I have another model M2 for the probability distribution of another random variable Z that is parameterized by E(X|Y), the expected value of X given Y.

How would I use Pyro to find the marginal distribution of Z given Y? (In WebPPL I would simply use the expectation operator inside the model for Z.)
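A plain-Python sketch of the nested pattern, under hypothetical choices for both models: M1 is X | Y ~ Normal(2Y, 1) and M2 is Z ~ Normal(E[X | Y], 1). The inner expectation is estimated by Monte Carlo inside every outer draw, mirroring what `expectation(Infer(...))` does inside a WebPPL model. Here E[X | Y = 1.5] = 3, so samples of Z should be centered near 3 with a standard deviation close to 1.

```python
import random
import statistics


def sample_x_given_y(y, rng):
    # Hypothetical M1: X | Y ~ Normal(2 * y, 1).
    return rng.gauss(2.0 * y, 1.0)


def inner_expectation(y, rng, num_inner=500):
    # Monte Carlo estimate of E[X | Y = y], computed inside the outer model.
    return sum(sample_x_given_y(y, rng) for _ in range(num_inner)) / num_inner


def sample_z_given_y(y, rng):
    # Hypothetical M2: Z ~ Normal(E[X | Y], 1), parameterized by the inner expectation.
    return rng.gauss(inner_expectation(y, rng), 1.0)


def marginal_z_samples(y, num_outer=2000, seed=0):
    rng = random.Random(seed)
    return [sample_z_given_y(y, rng) for _ in range(num_outer)]


def summarize(samples):
    return statistics.mean(samples), statistics.stdev(samples)


if __name__ == "__main__":
    print(summarize(marginal_z_samples(1.5)))
```

This only works directly because M1 can be sampled forward; if computing E(X|Y) itself required conditioning, the inner step would need its own inference run. Writing the inner draws as `pyro.sample` sites inside the outer Pyro model would also make them latent variables of that model, so conditioning on Z would change them, which is a different model from one parameterized by the expectation.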