I have 2.5 related issues, all of which came up while implementing the Laplace-family guide idea described here: https://github.com/pyro-ppl/pyro/issues/1817
Here’s how I understand Pyro’s basic control flow:
- Outer Adam loop over `n in range(num_optim_steps)` to optimize over φ:
  - Inner loop over `s in range(num_samples)`:
    - Run the guide with parameters φⁿ (n is an index, not an exponent):
      - The guide does some setup calculation to figure out the distribution q_φⁿ. In my case, since φ is θ̂, this involves getting the Hessian of log p(x, θ) at θ = θ̂. SLOW.
      - The guide draws 1 sample θⁿˢ from q_φⁿ and keeps track of log q_φⁿ(θⁿˢ).
    - Run the model to get log p(x, θⁿˢ).
    - Calculate the "crook" (the "inside" part of the ELBO, i.e. the quantity inside the expectation).
  - Estimate the ELBO as a weighted sum of the crooks over all s.
  - Get the gradient of that ELBO estimate.
  - Pass all that up to Adam.
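For reference, here is a plain-PyTorch sketch of the loop as I described it. It is not Pyro's internal code, and a few things are assumptions for illustration: the toy `log_joint` (a tiny logistic regression with a N(0, I) prior) stands in for log p(x, θ), and the guide is an ordinary mean-field Gaussian with learnable `loc` and `log_scale`, so its "setup" step is cheap. In the Laplace case that step would be the Hessian computation.

```python
import torch

# Hypothetical toy model (not from the post): 2-D logistic regression with a
# N(0, I) prior; log_joint(theta) plays the role of log p(x, theta).
X = torch.tensor([[1.0, 0.5], [-0.3, 1.2], [0.8, -1.0], [-1.5, 0.2]])
y = torch.tensor([1.0, 1.0, 0.0, 0.0])

def log_joint(theta):
    logits = X @ theta
    log_lik = (y * logits - torch.nn.functional.softplus(logits)).sum()
    return -0.5 * theta.pow(2).sum() + log_lik

# phi: a plain mean-field Gaussian guide (NOT the Laplace guide)
loc = torch.zeros(2, requires_grad=True)
log_scale = torch.zeros(2, requires_grad=True)
opt = torch.optim.Adam([loc, log_scale], lr=0.05)
num_optim_steps, num_samples = 200, 8

for n in range(num_optim_steps):                 # outer Adam loop over phi^n
    opt.zero_grad()
    crooks = []
    for s in range(num_samples):                 # inner loop over samples
        # guide setup: turn phi^n into q_{phi^n} (the Hessian step in the Laplace case)
        q = torch.distributions.Normal(loc, log_scale.exp())
        theta = q.rsample()                      # one reparameterized sample theta^{ns}
        log_q = q.log_prob(theta).sum()          # log q_{phi^n}(theta^{ns})
        crooks.append(log_joint(theta) - log_q)  # the "crook" for this sample
    elbo = torch.stack(crooks).mean()            # equal-weight average of the crooks
    (-elbo).backward()                           # gradient of the ELBO estimate
    opt.step()                                   # hand it to Adam
```

As far as I know, Pyro's `Trace_ELBO(num_particles=...)` averages over particles in a similar way, but the sketch above only mirrors the structure, not Pyro's surrogate-loss details.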
I see a few problems with that.
Problem 1: the entire guide is run num_samples times per step, even though φ is identical each time, so the slow setup calculation in the guide should give the same answer every time. In my case, this is a huge waste of time. I wish there were a way to split the guide into two parts, so that the precalculation step could run once and return a function that could then be run num_samples times.
Problem 1a: If I memoized the precalculation result (in my case, the Hessian), it would become opaque to the ELBO gradient. Yuck.
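Outside Pyro's SVI/ELBO machinery, the split from Problem 1 is easy to sketch in plain PyTorch: build q once per optimizer step, then draw all num_samples reparameterized samples from it with `rsample((num_samples,))`. This also bears on Problem 1a: the Hessian is recomputed every step with `create_graph=True` and only reused within that step's autograd graph, rather than cached across different values of φ, so it stays differentiable with respect to θ̂. The toy model and the names `prepare_guide`, `elbo_estimate`, and `setup_calls` are hypothetical, and this does not show how to wire the split into an actual Pyro guide.

```python
import torch

# Hypothetical toy model (not from the post): logistic regression with a
# N(0, I) prior, chosen so the Hessian depends on theta.
X = torch.tensor([[1.0, 0.5], [-0.3, 1.2], [0.8, -1.0], [-1.5, 0.2]])
y = torch.tensor([1.0, 1.0, 0.0, 0.0])

def log_joint(theta):
    # theta may carry leading sample dims: shape (..., D) -> result shape (...)
    logits = theta @ X.T
    # exp/log1p (rather than softplus) so autograd can differentiate three times
    log_lik = (y * logits - torch.log1p(torch.exp(logits))).sum(-1)
    return -0.5 * theta.pow(2).sum(-1) + log_lik

setup_calls = 0  # counts how often the slow step runs

def prepare_guide(theta_hat):
    """Part 1 (slow, once per optimizer step): Laplace q built at theta_hat."""
    global setup_calls
    setup_calls += 1
    # create_graph=True keeps H differentiable with respect to theta_hat
    H = torch.autograd.functional.hessian(log_joint, theta_hat, create_graph=True)
    # q = N(theta_hat, (-H)^{-1}); -H is positive definite for this model
    return torch.distributions.MultivariateNormal(theta_hat, precision_matrix=-H)

def elbo_estimate(theta_hat, num_samples):
    q = prepare_guide(theta_hat)                     # setup runs once...
    thetas = q.rsample((num_samples,))               # ...then num_samples draws
    crooks = log_joint(thetas) - q.log_prob(thetas)  # shape (num_samples,)
    return crooks.mean()

theta_hat = torch.zeros(2, requires_grad=True)       # phi = theta_hat
opt = torch.optim.Adam([theta_hat], lr=0.05)
num_optim_steps, num_samples = 100, 10

for n in range(num_optim_steps):
    opt.zero_grad()
    loss = -elbo_estimate(theta_hat, num_samples)
    loss.backward()  # gradient flows through the samples AND through H
    opt.step()

print(theta_hat.detach(), setup_calls)
```

The design choice is simply to return a distribution object from the slow part and sample from it in a batch; nothing is detached, so no gradient information is lost.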
Problem 2: since both PyTorch and poutine appear to need hand-holding to calculate gradients of gradients, I currently have to wrap the argument of the Hessian call in `poutine.block(…)` so that the “gradient of the ELBO” step doesn’t cause an error. But that means the Hessian is then treated as fixed, not as something that itself has a gradient with respect to φ, so the outer optimization can go off in the wrong direction. I need to fix this, and I’d even be willing to hand-code something that calculates just one element of the Hessian matrix at a time in order to do so, but I don’t understand poutine well enough to do this.
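On the PyTorch side, gradients of gradients are supported via `torch.autograd.grad(..., create_graph=True)`. Below is a minimal sketch, assuming a hypothetical toy model, of the hand-coded route from Problem 2: build the Hessian one row per backward pass (autograd naturally returns a full row for each element of the gradient) while keeping it in the graph, so anything computed from it has a gradient with respect to θ̂. It runs outside any poutine handler, so it does not show how this interacts with `poutine.block`; `hessian_rowwise` is an illustrative name, not a Pyro or PyTorch API.

```python
import torch

# Hypothetical toy model (not from the post): logistic regression with a
# N(0, I) prior, in float64 so the result can be checked numerically.
X = torch.tensor([[1.0, 0.5], [-0.3, 1.2], [0.8, -1.0], [-1.5, 0.2]], dtype=torch.float64)
y = torch.tensor([1.0, 1.0, 0.0, 0.0], dtype=torch.float64)

def log_joint(theta):
    logits = X @ theta
    # exp/log1p keeps every op differentiable to third order
    return -0.5 * theta.pow(2).sum() + (y * logits - torch.log1p(torch.exp(logits))).sum()

def hessian_rowwise(f, theta):
    """Hessian of scalar f at theta, one row per backward pass, kept in the graph."""
    (g,) = torch.autograd.grad(f(theta), theta, create_graph=True)  # differentiable gradient
    rows = []
    for i in range(theta.numel()):
        # row i = d g_i / d theta; create_graph=True lets us differentiate it again
        (row,) = torch.autograd.grad(g[i], theta, create_graph=True)
        rows.append(row)
    return torch.stack(rows)

theta_hat = torch.tensor([0.3, -0.7], dtype=torch.float64, requires_grad=True)
H = hessian_rowwise(log_joint, theta_hat)

# A quantity built from H, e.g. the log-determinant that appears in the Laplace
# entropy, now has a gradient with respect to theta_hat through H itself.
penalty = torch.logdet(-H)
(grad_through_H,) = torch.autograd.grad(penalty, theta_hat)
print(grad_through_H)
```

If single elements are really needed, H[i, j] is just `row[j]`, but that saves no work, since each backward pass produces the whole row anyway.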
Thanks in advance for any help/hints that anyone can give me.