How should I manually adjust the gradient pyro sees?

I have a model p(θ,y,z) with global parameters θ, unit-level latent variables yᵢ, and unit-level observations zᵢ. I’d like to define a guide family q_φ as follows:

  1. The guide parameter φ is just θ̂, the mode of θ.
  2. Calculate ŷ, the MLE of y given θ=θ̂ and z=z_obs. (That is, amortize over y, using calculus I have done symbolically.)
  3. Find the observed information matrix P, i.e. the negative of the Hessian of log p with respect to (θ,y) at the point (θ̂,ŷ,z_obs).
  4. Let q_φ(θ,y) be normal with mean (θ̂,ŷ) and precision matrix P. (In reality, we may need to modify P somewhat, to ensure that it is finite and positive-definite. But for the present purposes, let’s just assume that it is.)
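Steps 1 to 4 can be sketched on a hypothetical toy problem with scalar θ and y. The observations are folded into an unnormalized log joint, log p(θ, y) = −θ²/2 − e^θ(y − 2θ)²/2, for which ŷ(θ) = 2θ can be worked out by hand.

This torch-only sketch builds q_φ at a fixed θ̂ only. It says nothing about gradients with respect to θ̂, because torch.autograd.functional.hessian is called without create_graph, so P carries no gradient. The names log_joint, y_mle, observed_information and laplace_q are made up for illustration.

```python
import torch
from torch.autograd.functional import hessian
from torch.distributions import MultivariateNormal


def log_joint(theta, y):
    # Hypothetical toy: unnormalized log p(theta, y) for scalar theta and y.
    return -0.5 * theta**2 - 0.5 * torch.exp(theta) * (y - 2 * theta) ** 2


def y_mle(theta):
    # Step 2: argmax over y of log_joint(theta, y), worked out by hand for this toy.
    return 2 * theta


def observed_information(theta_hat):
    # Step 3: P = minus the Hessian of log p w.r.t. (theta, y) at (theta_hat, y_hat).
    # hessian() differentiates w.r.t. its inputs as independent variables.
    y_hat = y_mle(theta_hat)
    (h_tt, h_ty), (h_yt, h_yy) = hessian(log_joint, (theta_hat, y_hat))
    return -torch.stack([torch.stack([h_tt, h_ty]), torch.stack([h_yt, h_yy])])


def laplace_q(theta_hat):
    # Step 4: normal with mean (theta_hat, y_hat) and precision matrix P.
    mean = torch.stack([theta_hat, y_mle(theta_hat)])
    return MultivariateNormal(mean, precision_matrix=observed_information(theta_hat))


theta_hat = torch.tensor(0.0)  # step 1: the only free parameter
print(observed_information(theta_hat))  # expected values [[5, -2], [-2, 1]]
draw = laplace_q(theta_hat).rsample()   # a reparameterized draw of (theta, y)
```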

I have successfully gotten pyro to do an un-amortized version of this approach — i.e. letting φ = (θ̂,ŷ) rather than calculating ŷ from θ̂.

The difficulty with the amortized version is that, when calculating the Hessian of p at (θ̂,ŷ), pyro needs to forget about the dependence of ŷ on θ̂: it needs to treat ∂ŷ/∂θ̂ as 0. On the other hand, when calculating the gradient of the ELBO with respect to θ̂, the dependence of ŷ on θ̂ is crucial.
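The conflict can be seen in miniature with a made-up term f(θ, y) = θ·y² and a made-up amortization ŷ(θ) = θ². Neither comes from the model above; this is a torch-only sketch.

At θ̂ = 1, detaching ŷ leaves only the partial derivative ∂f/∂θ = ŷ² = 1. The total derivative also includes the path through ŷ: df/dθ̂ = ŷ² + 2θ̂ŷ·2θ̂ = 5θ̂⁴ = 5.

```python
import torch


def f(theta, y):
    # Hypothetical stand-in for a term of the loss that depends on both theta and y.
    return theta * y**2


def y_of(theta):
    # Hypothetical amortized y_hat(theta).
    return theta**2


def grad_theta(detach_y):
    theta_hat = torch.tensor(1.0, requires_grad=True)
    y_hat = y_of(theta_hat)
    if detach_y:
        # Cuts the path theta_hat -> y_hat, so d y_hat / d theta_hat is treated as 0.
        y_hat = y_hat.detach().requires_grad_()
    f(theta_hat, y_hat).backward()
    return theta_hat.grad.item()


print(grad_theta(detach_y=True))   # 1.0: partial derivative only
print(grad_theta(detach_y=False))  # 5.0: total derivative
```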

I know how to make the dependence of ŷ on θ̂ invisible to pyro/pytorch, but what I need to do here is make it visible for some purposes and invisible for others. I can think of a few possibilities:

A. Make the dependence invisible in the guide function itself, but calculate the missing terms of the gradient separately and pass them to pyro explicitly. (Using some poutine call? I'm not sure how.)

B. Calculate an extra_loss variable in such a way that running extra_loss.backward(retain_graph=True) adds the correct amount to the gradient, and trust that pyro will not zero out that gradient before using it.

C. Add some hooks to pyro.param (in my private copy, of course), so that a guide can ask for two identical copies of a given parameter and use them in different places. Define the gradient of a function with respect to such a “doubled” parameter to be the sum of the gradients with respect to each copy, as calculated by .backward().

The issue with A and B is that I can easily imagine getting things wrong, since I don’t fully understand all the scaling and Rao-Blackwellization involved in pyro’s estimate of the ELBO. What exactly is the expression for the thing that pyro runs backward() on? (The ELBO or surrogate ELBO? I think?)
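For what it is worth, here is my reading of Pyro's Trace_ELBO; it is worth double-checking against the source.

When every guide sample site is reparameterizable, as the auxiliary MultivariateNormal site is here, the surrogate ELBO has the same value as the plain ELBO estimate. The loss that backward() is called on is then the negated single-sample estimate, averaged over S = num_particles draws. The Delta sites contribute zero to log q.

All dependence on φ = θ̂ then flows through the mean (θ̂, ŷ), through P (and hence its Cholesky factor L), and through log q_φ:

$$
\mathrm{loss}(\hat\theta) = -\frac{1}{S}\sum_{s=1}^{S}\Big[\log p\big(\theta^{(s)}, y^{(s)}, z_{\mathrm{obs}}\big) - \log q_\phi\big(\theta^{(s)}, y^{(s)}\big)\Big],
\qquad
\big(\theta^{(s)}, y^{(s)}\big) = (\hat\theta, \hat y) + L\,\varepsilon^{(s)},\quad \varepsilon^{(s)} \sim \mathcal{N}(0, I),\quad L L^{\top} = P^{-1}
$$

Non-reparameterizable sites would add score-function terms to the surrogate, which is where the Rao-Blackwellization bookkeeping comes in.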

The issue with C is that it involves changing pyro itself. Since I’m not a pyro pro, this seems risky. Still, I currently lean towards option C, because it seems the simplest; but I may be totally wrong about that.

Would really appreciate advice and words of wisdom and warning!

It sounds like you could detach y before computing the Hessian, as in

y = y.detach().requires_grad_()

This is how I generally use newton_step(), for example. See also the newton_step tests, which compute a Hessian under the hood but detach beforehand to avoid endless backprop.
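Why the Hessian itself must not see ∂ŷ/∂θ̂ can be checked on a hypothetical toy log joint, log p(θ, y) = −θ²/2 − e^θ(y − 2θ)²/2, whose ŷ(θ) = 2θ.

Differentiating θ ↦ log p(θ, ŷ(θ)) twice gives −1, the curvature along the curve y = ŷ(θ). The θθ entry of the joint Hessian at (θ̂, ŷ) = (0, 0) is −5, and detaching y with y.detach().requires_grad_() gives that joint entry. A torch-only sketch (log_joint and second_derivative_wrt_theta are made-up names):

```python
import torch


def log_joint(theta, y):
    # Hypothetical toy log joint (unnormalized).
    return -0.5 * theta**2 - 0.5 * torch.exp(theta) * (y - 2 * theta) ** 2


def second_derivative_wrt_theta(detach_y):
    theta = torch.tensor(0.0, requires_grad=True)
    y = 2 * theta  # hypothetical y_hat(theta)
    if detach_y:
        # y becomes an independent leaf, so d y / d theta is treated as 0
        y = y.detach().requires_grad_()
    (g,) = torch.autograd.grad(log_joint(theta, y), theta, create_graph=True)
    (h,) = torch.autograd.grad(g, theta)
    return h.item()


print(second_derivative_wrt_theta(detach_y=False))  # -1.0: curvature along y = y_hat(theta)
print(second_derivative_wrt_theta(detach_y=True))   # -5.0: joint Hessian entry
```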

Thanks for the fast response. Using detach() is one way to calculate the Hessian without accounting for ŷ’s dependence on θ̂. But as I said, I already know how to do that. What I want to do is treat ŷ as detached when calculating the Hessian, but then calculate the gradient that pyro itself uses to optimize φ as if ŷ were not detached. Otherwise, the optimization goes wrong, increasing the bias of the final outcome.

I think that the answer will probably be some form of one of the three options I listed. Here’s pseudocode for each of them:

Current version, with detach but no way to “re-attach” the gradient:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from collections import OrderedDict

# model, dim, get_y_MLE, ensure_positive_definite and myhessian are defined elsewhere

def laplace_guide(observations):
    # holds the values used to condition the model and get the Hessian
    hat_data = OrderedDict()

    theta_hat = pyro.param("mode_hat", torch.zeros(dim))
    hat_data.update(theta=theta_hat)

    y_hat = get_y_MLE(theta_hat, observations)

    # break the link from theta_hat to y_hat so the Hessian treats them as independent
    detached_y_hat = y_hat.detach().requires_grad_()
    hat_data.update(y=detached_y_hat)

    hessCenter = poutine.condition(model, data=hat_data)
    blockedTrace = poutine.block(poutine.trace(hessCenter).get_trace)(observations)
    logPosterior = blockedTrace.log_prob_sum()
    Info = ensure_positive_definite(-myhessian.hessian(logPosterior, list(hat_data.values())))

    theta_and_y_mean = torch.cat([part.view(-1) for part in hat_data.values()], 0)
    # auxiliary site: not present in the model
    theta_and_y = pyro.sample('theta_and_y',
                              dist.MultivariateNormal(theta_and_y_mean, precision_matrix=Info),
                              infer={'is_auxiliary': True})

    # decompose theta_and_y into the model's latent sites (assumes y is a flat vector)
    theta, y = theta_and_y[:dim], theta_and_y[dim:]
    pyro.sample('theta', dist.Delta(theta).to_event(1))
    pyro.sample('y', dist.Delta(y).to_event(1))

Option A, add at the end:

    guide_log_density = dist.MultivariateNormal(theta_and_y_mean, precision_matrix=Info).log_prob(theta_and_y)
    guide_log_density.backward(retain_graph=True)
    theta_hat.grad = None  # discard the direct part; keep detached_y_hat.grad
    loss_transfer_value = (detached_y_hat.grad * y_hat).sum()
    loss_transfer_value.backward(retain_graph=True)
    poutine.tell_pyro_to_add_this_to_gradient(theta_hat.grad)  # hypothetical: no such poutine call exists
    theta_hat.grad = None

Option B: as above, but without the last two lines.
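The gradient-transfer line in option A (detached_y_hat.grad * y_hat) relies on a standard autograd fact. If c is held constant, the gradient of (c·ŷ(θ̂)).sum() with respect to θ̂ is (∂ŷ/∂θ̂)ᵀc. In addition, successive backward() calls accumulate into .grad.

A torch-only sketch with a made-up loss f(θ, y) = θ·y² and a made-up ŷ(θ) = θ² shows the two passes adding up to the total derivative 5 at θ̂ = 1. Whether pyro keeps such accumulated gradients inside SVI.step is exactly the open question in option B, and this sketch does not test it.

```python
import torch

theta_hat = torch.tensor(1.0, requires_grad=True)
y_hat = theta_hat**2  # hypothetical amortized y_hat(theta_hat)
detached_y_hat = y_hat.detach().requires_grad_()

# Pass 1: the loss sees y_hat as an independent leaf.
loss = theta_hat * detached_y_hat**2  # hypothetical stand-in for the loss
loss.backward()
partial = theta_hat.grad.item()  # d loss / d theta_hat with y_hat held fixed

# Pass 2: inject the missing chain-rule term (d y_hat / d theta_hat) * (d loss / d y_hat).
# detached_y_hat.grad is a constant here, so extra_loss's value is irrelevant;
# only its gradient matters.
extra_loss = (detached_y_hat.grad * y_hat).sum()
extra_loss.backward()
total = theta_hat.grad.item()  # gradients accumulate across backward() calls

print(partial, total)  # 1.0 5.0
```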

Option C, replace:

    theta_hat = pyro.param("mode_hat", torch.zeros(dim))
    hat_data.update(theta=theta_hat)

    y_hat = get_y_MLE(theta_hat, observations)

    detached_y_hat = y_hat.detach().requires_grad_()
    hat_data.update(y=detached_y_hat)

with:

    # num_copies is hypothetical: it would exist only in my modified copy of pyro.param
    theta_hat, theta_hat_copy = pyro.param("mode_hat", torch.zeros(dim), num_copies=2)
    hat_data.update(theta=theta_hat)

    y_hat = get_y_MLE(theta_hat_copy, observations)
    hat_data.update(y=y_hat)

Then, when calculating the Hessian, y_hat and theta_hat will not be connected in the autograd graph. But when pyro calculates the gradient, the connection between y_hat and theta_hat_copy will be present, so the correct gradient for theta_hat will just be the sum of the gradients on theta_hat and theta_hat_copy.
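The two-copies idea can be sketched in plain torch. The key assumption, which is not a pyro feature, is that an ordinary clone() can serve as the second copy.

The approach is to differentiate with respect to clones of θ̂ and ŷ; non-leaf tensors are valid inputs to torch.autograd.grad. Keeping create_graph=True makes the resulting Hessian stay differentiable with respect to θ̂ through both paths.

On a hypothetical toy log p(θ, y) = −θ²/2 − e^θ(y − 2θ)²/2 with ŷ(θ) = 2θ, two things come out right:
- P is the joint one, [[5, −2], [−2, 1]] at θ̂ = 0.
- The gradient of P₁₁ = 1 + 4e^θ̂, used here as a stand-in for an ELBO term, is the total derivative 4, not the partial 12 that a detached ŷ would give.

The names log_joint and joint_hessian are made up.

```python
import torch


def log_joint(theta, y):
    # Hypothetical toy log joint (unnormalized).
    return -0.5 * theta**2 - 0.5 * torch.exp(theta) * (y - 2 * theta) ** 2


def joint_hessian(f, xs):
    # Hessian of f w.r.t. the tensors in xs, treated as independent variables.
    # create_graph=True keeps the result differentiable w.r.t. whatever xs depend on.
    grads = torch.autograd.grad(f, xs, create_graph=True)
    rows = [torch.stack(torch.autograd.grad(g, xs, create_graph=True)) for g in grads]
    return torch.stack(rows)


theta_hat = torch.tensor(0.0, requires_grad=True)  # stands in for pyro.param("mode_hat")
y_hat = 2 * theta_hat                              # hypothetical get_y_MLE(theta_hat)

# The "copies": clones are new graph nodes, so there is no path from y_var to theta_var,
# yet both still depend on theta_hat further up the graph.
theta_var = theta_hat.clone()
y_var = y_hat.clone()

info = -joint_hessian(log_joint(theta_var, y_var), (theta_var, y_var))
print(info.detach())  # expected values [[5, -2], [-2, 1]]

loss = info[0, 0]  # stand-in for an ELBO term that depends on P
loss.backward()
print(theta_hat.grad.item())  # expected 4.0, the total derivative of 1 + 4 * exp(theta_hat)
```

The same pattern would apply inside the guide: pass the clones, rather than detached tensors, to the Hessian routine.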