In my guide, in order to calculate the Hessian, I’m using the following code:

    hess_center = pyro.condition(model, transformed_hat_data)
    mytrace = poutine.block(poutine.trace(hess_center).get_trace)(data, scale, include_nuisance)
    log_posterior = mytrace.log_prob_sum()
    neg_big_hessian, big_grad = myhessian.arrowhead_hessian(log_posterior, theta_parts,
                                                            len(theta_hat_data),  # tensors, not elements = tlen
                                                            blocksize, return_grad=True)

(Think of myhessian.arrowhead_hessian as just hessian.hessian.)
When the model includes plate and/or scale statements, I get errors related to the log_prob shape on line 3 (log_posterior = mytrace.log_prob_sum()). I think that's because I'm not doing the conditioning correctly on line 1. I've tried fiddling with this in various ways, and the errors change, but I can't get it to actually work. Any help? Thanks in advance!