Memoize guide precalculation; gradients of gradients (question in 2 posts)


I have 2.5 related issues, all of which arise in implementing the idea of a Laplace family guide as described here:

Here’s how I understand the basic control flow of how pyro works:

Outer Adam loop over n in range(num_optim_steps) to optimize over φ:
    inner loop over s in range(num_samples):
        run guide with parameters φⁿ (n is index, not exponent):
            guide does some setup calculation to figure out the distribution for q_φⁿ. 
                In my case, since φ is θ̂, this involves getting the Hessian of p(x,θ̂) — SLOW
            guide gets 1 sample θⁿˢ from qᵩ — and keeps track of log[qᵩ(θⁿˢ)]
        run model to get log(p(x,θⁿˢ))
        calculate "crook" (the "inside" part of ELBO—inside the expectation)
    estimate ELBO as a weighted sum of all s crooks
    get gradient of that ELBO estimate
    pass all that up to Adam
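The control flow above can be sketched in plain PyTorch (toy distributions and stand-in names of my own, not Pyro's actual internals), which also shows where the slow setup lands inside the sampling loop:

```python
import torch

def expensive_setup(phi):
    # Stand-in for the slow Hessian computation inside the guide.
    return phi ** 2

def elbo_estimate(phi, num_samples=4):
    total = 0.0
    for s in range(num_samples):
        scale = expensive_setup(phi)           # recomputed every sample!
        theta = phi + scale * torch.randn(())  # one reparameterized sample from q_phi
        log_q = -0.5 * ((theta - phi) / scale) ** 2  # toy log q_phi(theta), up to a constant
        log_p = -0.5 * theta ** 2                    # toy log p(x, theta), up to a constant
        total = total + (log_p - log_q)        # the "crook"
    return total / num_samples

phi = torch.tensor(1.0, requires_grad=True)
loss = -elbo_estimate(phi)
loss.backward()  # gradient of the ELBO estimate, handed to Adam
```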

I see a few problems with that.

Problem 1: the entire guide is run num_samples times per step even though φ is identical across all those runs, so the slow setup calculation in the guide gives the same answer every time. In my case this is a huge waste of time. I wish there were a way to split the guide into two parts, so that you could run the precalculation step once and get a function that you could then run num_samples times.
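The split I have in mind is essentially a closure (sketched here with hypothetical names, outside Pyro's actual guide API): the expensive part runs once and returns the cheap sampler:

```python
import torch

def make_sampler(phi):
    # Run the slow setup once per optimization step...
    scale = phi ** 2  # stand-in for the Hessian computation
    def sample():
        # ...and reuse its result for each of the num_samples draws.
        return phi + scale * torch.randn(())
    return sample

phi = torch.tensor(1.0, requires_grad=True)
sample = make_sampler(phi)              # expensive part runs once
thetas = [sample() for _ in range(4)]   # cheap part runs num_samples times
```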

Problem 1a: If I memoized the precalculation result (in my case, the Hessian), it would become opaque to the ELBO gradient. Yuck.

Problem 2: since both PyTorch and poutine appear to need hand-holding to calculate gradients of gradients, I currently have to wrap the argument of the Hessian call in poutine.block(…) so that the "gradient of the ELBO" step doesn't raise an error. But that means the Hessian is then treated as something fixed, not something that itself has a gradient with respect to φ, which can send the outer optimization off in the wrong direction. I need to fix this, and I'd even be willing to hand-code something that calculates just one element of the Hessian matrix at a time in order to do so, but I don't understand poutine well enough to manage it.
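For reference, outside of poutine a single Hessian element can be computed one at a time in plain PyTorch with two calls to torch.autograd.grad, keeping create_graph=True on the first call so the result stays differentiable (toy function of my own here):

```python
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
f = x[0] ** 2 * x[1]  # toy stand-in for log p(x, theta)

# First derivative, kept in the graph so we can differentiate it again.
(g,) = torch.autograd.grad(f, x, create_graph=True)
# g == (2*x0*x1, x0**2)

# Second derivative: d/dx1 of df/dx0, i.e. Hessian element H[0][1] = 2*x0.
(h_row,) = torch.autograd.grad(g[0], x)
h01 = h_row[1]  # 2 * 2.0 = 4.0
```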

Thanks in advance for any help/hints that anyone can give me.


Note: It may be that I'm misunderstanding what poutine.block(…) actually does, and thus that I've misstated problem 2. Does it, or doesn't it, make the Hessian appear to Pyro as if it were a constant, with zero gradient with respect to φ?

One idea that I had was to manually calculate and memoize the Hessian and its gradient the first time around, then on later calls which have the same values for φ, to construct a linear form which has the same value and gradient. I know that’s a hack, but is it a crazy one?
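That hack amounts to a first-order Taylor surrogate: cache the value and gradient at φ₀ once, then on later calls with the same φ return a linear form that reproduces both (a minimal sketch for a scalar-valued expensive function; names are mine):

```python
import torch

def expensive(phi):
    return phi ** 3  # stand-in for the Hessian computation

# First call: compute the value and its gradient once, then cache them.
phi0 = torch.tensor(2.0, requires_grad=True)
val0 = expensive(phi0)
(grad0,) = torch.autograd.grad(val0, phi0)

def surrogate(phi):
    # Linear form: matches expensive() in value and gradient at phi0,
    # but costs only one multiply-add on later calls.
    return val0.detach() + grad0.detach() * (phi - phi0.detach())

phi = torch.tensor(2.0, requires_grad=True)
out = surrogate(phi)
out.backward()
# out == expensive(phi0) == 8.0, and phi.grad == grad0 == 12.0 at phi == phi0
```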


Possible solutions include:

  • setting vectorize_particles=True (together with num_particles) on your ELBO;
  • using local sample sites via TraceEnum_ELBO with pyro.sample(..., infer={'num_samples': 100});
  • memoizing/caching your expensive computation.

Hmm just to eliminate possible confusion, note that poutine.block is a Pyro concept, rather than a PyTorch concept. It blocks sites/messages from propagating to enclosing Pyro effect handlers. It does not block gradients. When I want to block gradients, I usually use Tensor.detach() or torch.no_grad().
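A minimal illustration of the difference: detach() is what actually stops gradient flow in PyTorch, regardless of any Pyro handlers:

```python
import torch

phi = torch.tensor(3.0, requires_grad=True)

# Gradient flows through both terms: d/dphi (phi**2 + phi) = 2*phi + 1 = 7.
y = phi ** 2 + phi
y.backward()
full_grad = phi.grad.item()  # 7.0

# detach() makes the first term a constant: only d/dphi (phi) = 1 survives.
phi.grad = None
y = (phi ** 2).detach() + phi
y.backward()
blocked_grad = phi.grad.item()  # 1.0
```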

General advice because I don’t fully understand your issues :smile::
When I want to use second-order methods in Pyro, I use TraceEnum_ELBO.differentiable_loss() rather than SVI. A few of the Trace*_ELBO classes have a .differentiable_loss() method, but TraceEnum_ELBO's method is infinitely differentiable whereas the others are only once differentiable. Here is an example using pyro.ops.newton_step() which computes a Hessian internally.
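The mechanism underneath a differentiable loss is PyTorch's create_graph=True: the first gradient stays in the autograd graph, so it can itself be differentiated. A toy example of the gradient-of-gradient pattern (with a stand-in loss, not an actual ELBO):

```python
import torch

phi = torch.tensor(2.0, requires_grad=True)
loss = phi ** 3  # stand-in for a differentiable ELBO loss

# First-order gradient, kept differentiable so we can take another grad.
(g,) = torch.autograd.grad(loss, phi, create_graph=True)  # 3 * phi**2 = 12
# Second-order gradient (what a Newton-style update needs).
(h,) = torch.autograd.grad(g, phi)                        # 6 * phi = 12
```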


Thanks for the answers. Both are helpful, even though they answer slightly different questions than the ones I asked.

Part of that is because I asked the wrong questions. I've now realized that you're right: it wasn't "block" that was making my Hessian opaque to gradients, but rather the Hessian code I was using. Unfortunately, I think there will be performance issues with what I want to do using any kind of backwards-mode AD.