In my guide, I compute the Hessian of the log posterior with the following code:
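```python
import pyro
from pyro import poutine
# Fix the sites in transformed_hat_data (a dict: site name -> value) at those values
hess_center = pyro.condition(model, transformed_hat_data)
# Trace one run of the conditioned model; block() hides its sites from any enclosing handlers
mytrace = poutine.block(poutine.trace(hess_center).get_trace)(data, scale, include_nuisance)
log_posterior = mytrace.log_prob_sum()
neg_big_hessian, big_grad = myhessian.arrowhead_hessian(log_posterior, theta_parts,
                                                        len(theta_hat_data),  # number of tensors, not elements (= tlen)
                                                        blocksize,
                                                        return_grad=True)
```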
(For the purposes of this question, think of `myhessian.arrowhead_hessian` as plain `hessian.hessian`.)
When the model includes plate and/or scale statements, I get errors about the `log_prob` shape at `log_posterior = mytrace.log_prob_sum()`. I suspect that's because I'm not conditioning correctly in the `pyro.condition(model, transformed_hat_data)` call. I've tried changing it in various ways, and the error messages change, but I can't get it to actually work. Any help? Thanks in advance!