I have a model p(θ,y,z) with global parameters θ, unit-level latent variables yᵢ, and unit-level observations zᵢ. I’d like to define a guide family q_φ as follows:
- The guide parameter φ is just θ̂, the mode of θ.
- Calculate ŷ, the MLE of y given θ=θ̂ and z=z_obs. (That is, amortize over y, using calculus I have done symbolically.)
- Find the observed information matrix P, i.e. the negative of the Hessian of log p with respect to (θ,y) at the point (θ̂,ŷ,z_obs).
- Let q_φ(θ,y) be normal with mean (θ̂,ŷ) and precision matrix P. (In reality, we may need to modify P somewhat, to ensure that it is finite and positive-definite. But for the present purposes, let’s just assume that it is.)
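A minimal PyTorch sketch of the four steps above, on a hypothetical toy model chosen so that ŷ has a closed form. The model, the data, and the names `log_joint`, `y_hat_of` and `guide_dist` are illustrative assumptions, not the code from this post. It uses only plain PyTorch (`torch.autograd.functional.hessian`, `torch.distributions.MultivariateNormal`). How the joint (θ, y) draw would be mapped onto the model's separate Pyro sample sites is not shown.

```python
import torch
from torch.autograd.functional import hessian
from torch.distributions import MultivariateNormal

# Hypothetical toy model (an assumption, not the model from the post):
#   theta ~ Normal(0, 1)
#   y_i | theta ~ Normal(theta, 1)
#   z_i | y_i, theta ~ Normal(y_i, variance exp(-theta))
def log_joint(x, z):
    """log p(theta, y, z) up to an additive constant, with x = [theta, y_1, ..., y_n]."""
    theta, y = x[0], x[1:]
    return (-0.5 * theta**2
            - 0.5 * ((y - theta) ** 2).sum()
            + (0.5 * theta - 0.5 * torch.exp(theta) * (z - y) ** 2).sum())

def y_hat_of(theta, z):
    """Closed-form argmax over y of log_joint for fixed theta (the 'symbolic calculus' step)."""
    w = torch.exp(theta)
    return (theta + w * z) / (1 + w)

def guide_dist(theta_hat, z):
    y_hat = y_hat_of(theta_hat, z)                     # step 2: amortized y_hat(theta_hat)
    x_hat = torch.cat([theta_hat.reshape(1), y_hat])   # guide mean (theta_hat, y_hat)
    # Step 3: observed information = -Hessian of log p w.r.t. (theta, y).
    # create_graph=True keeps P differentiable with respect to theta_hat.
    P = -hessian(lambda x: log_joint(x, z), x_hat, create_graph=True)
    # Step 4: Gaussian with that mean and precision (P assumed positive definite).
    return MultivariateNormal(loc=x_hat, precision_matrix=P)

z_obs = torch.tensor([0.5, -1.0], dtype=torch.float64)
theta_hat = torch.tensor(0.0, dtype=torch.float64, requires_grad=True)  # plays the role of phi
q = guide_dist(theta_hat, z_obs)
sample = q.rsample()  # reparameterized draw of (theta, y_1, y_2)
```

With θ̂ = 0 and z_obs = (0.5, -1), this gives ŷ = (0.25, -0.5) and a precision matrix of [[3.15625, -1.25, -0.5], [-1.25, 2, 0], [-0.5, 0, 2]], which is positive definite, so no modification of P is needed in this toy case.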
I have successfully gotten pyro to do an un-amortized version of this approach — i.e. letting φ = (θ̂,ŷ) rather than calculating ŷ from θ̂.
The difficulty with the amortized version is that, when calculating the Hessian of log p at (θ̂,ŷ), pyro needs to forget about the dependence of ŷ on θ̂: it needs to treat ∂ŷ/∂θ̂ as 0. On the other hand, when calculating the gradient of the ELBO with respect to θ̂, the dependence of ŷ on θ̂ is crucial.
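The two roles of ŷ can be seen on a hypothetical two-variable stand-in for log p (`f`, `theta_hat`, `y_hat`, `naive` are illustrative names, not from the post). `torch.autograd.functional.hessian` differentiates with respect to the tensor it is handed, so evaluating it at x̂ = (θ̂, ŷ(θ̂)) yields partial derivatives in which ∂ŷ/∂θ̂ is effectively 0. With `create_graph=True` the result stays attached to θ̂ through both the direct path and the path through ŷ, so an outer gradient still sees the dependence. This is a plain-PyTorch sketch under those assumptions, not a claim about what Pyro does internally.

```python
import torch
from torch.autograd.functional import hessian

# Hypothetical stand-in for log p: f(theta, y) = -theta^2/2 - (y - theta^2)^2/2.
# For fixed theta it is maximized at y_hat(theta) = theta^2, so y_hat depends on theta.
def f(x):
    theta, y = x[0], x[1]
    return -0.5 * theta**2 - 0.5 * (y - theta**2) ** 2

theta_hat = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
y_hat = theta_hat**2                        # amortized y_hat(theta_hat)
x_hat = torch.stack([theta_hat, y_hat])

# Partial Hessian in (theta, y) as independent variables: hessian() differentiates
# w.r.t. x_hat itself, so dy_hat/dtheta_hat does not enter these derivatives.
# At y = y_hat this is [[1 + 4*theta^2, -2*theta], [-2*theta, 1]].
P = -hessian(f, x_hat, create_graph=True)

# Contrast: twice differentiating the composite theta -> f(theta, y_hat(theta))
# bakes dy_hat/dtheta in, which is NOT the (theta, theta) entry of -P.
naive = hessian(lambda t: f(torch.stack([t, t**2])), theta_hat.detach())

# Outer gradient: dP[0, 0]/dtheta_hat follows the direct path AND the path through y_hat.
P[0, 0].backward()
print(P.detach(), naive.item(), theta_hat.grad.item())
```

At θ̂ = 1 this gives P = [[5, -2], [-2, 1]]. The composite second derivative is -1 rather than -5, and the gradient of P[0,0] with respect to θ̂ is 8 (it would be 12 if ŷ were detached).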
I know how to make the dependence of ŷ on θ̂ invisible to pyro/pytorch, but what I need here is to make it visible for some purposes and invisible for others. I can think of a few possibilities:
A. Make the dependence invisible in the guide function itself, but calculate the missing terms of the gradient separately and pass them to pyro explicitly. (Using some Poutine call?? Not sure how.)
B. Calculate an extra_loss variable in such a way that running extra_loss.backward(retain_graph=True) adds the correct amount to the gradient, and trust that pyro will not zero out that gradient before using it.
C. Add some hooks to pyro.param (in my private copy, of course), so that a guide can ask for two identical copies of a given parameter and use them in different places. Define the gradient of a function with respect to such a “doubled” parameter to be the sum of the gradients with respect to each copy, as calculated by .backward().
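Options B and C both rest on how PyTorch autograd accumulates gradients. A minimal pure-PyTorch check (no Pyro involved; `theta`, `copy_a`, `copy_b`, `main_loss`, `extra_loss` are illustrative names) shows two facts. When one tensor feeds several branches of a graph, autograd sums the per-branch gradients, which is the rule C proposes for a "doubled" parameter. Separate `backward()` calls add into the same `.grad` unless something zeroes it in between. The sketch does not show what Pyro does with `.grad` between such calls.

```python
import torch

theta = torch.tensor(1.5, requires_grad=True)

# (C) "Doubled" parameter: two copies that both stay attached to theta.
copy_a = theta.clone()
copy_b = theta.clone()
loss = copy_a**2 + 3.0 * copy_b
loss.backward()
grad_c = theta.grad.clone()   # 2*theta + 3: the per-copy gradients are summed

# (B) Separate backward() calls accumulate into the same .grad buffer.
theta.grad = None
h = theta**2                  # intermediate shared by both losses
main_loss = h
extra_loss = 3.0 * h
main_loss.backward(retain_graph=True)  # keep the shared graph alive for the second pass
extra_loss.backward()
grad_b = theta.grad.clone()   # 2*theta + 6*theta = 8*theta
```

Here `grad_c` is 6.0 and `grad_b` is 12.0. Without `retain_graph=True` on the first call, the second `backward()` would fail because the shared part of the graph would already have been freed.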
The issue with A and B is that I can easily imagine getting things wrong, since I don’t fully understand all the scaling and Rao-Blackwellization involved in pyro’s estimate of the ELBO. What exactly is the expression that pyro runs backward() on? (The ELBO, or the surrogate ELBO, I think?)
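For reference, here is the standard ELBO for the guide described above, written with a reparameterized sample so that every place θ̂ enters is explicit. This is a textbook derivation under the post's assumption that P is positive definite, with Cholesky factor P = LLᵀ and d the dimension of (θ, y). It does not describe Pyro's internal bookkeeping (scaling, subsampling, or surrogate terms).

```latex
\begin{aligned}
\mu(\hat\theta) &= \big(\hat\theta,\ \hat y(\hat\theta)\big), \qquad
P = P\big(\hat\theta, \hat y(\hat\theta)\big) = L L^{\top}, \qquad
(\theta, y) = \mu(\hat\theta) + L^{-\top}\varepsilon,\quad \varepsilon \sim \mathcal{N}(0, I_d) \\
\log q_\phi(\theta, y) &= \tfrac{1}{2}\log\det P - \tfrac{d}{2}\log 2\pi - \tfrac{1}{2}\,\varepsilon^{\top}\varepsilon \\
\mathrm{ELBO}(\hat\theta)
  &= \mathbb{E}_{q_\phi}\big[\log p(\theta, y, z_{\mathrm{obs}}) - \log q_\phi(\theta, y)\big] \\
  &= \mathbb{E}_{\varepsilon}\Big[\log p\big(\mu(\hat\theta) + L^{-\top}\varepsilon,\ z_{\mathrm{obs}}\big)\Big]
     - \tfrac{1}{2}\log\det P + \tfrac{d}{2}\,(1 + \log 2\pi)
\end{aligned}
```

Differentiating this with respect to θ̂ picks up ŷ′(θ̂) both in the mean μ and inside P, which is the dependence that must stay visible, while the Hessian that defines P is still taken with (θ, y) as independent variables. Pyro's `Trace_ELBO` minimizes the negative of a Monte Carlo estimate of this kind of quantity; whether its surrogate adds anything beyond it in this setting is the question above.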
The issue with C is that it involves changing pyro itself. Since I’m not a pyro pro, this seems risky. Still, I currently lean towards option C, because it seems the simplest; but I may be totally wrong about that.
Would really appreciate advice and words of wisdom and warning!