pyro.condition with contexts (plate, scale, etc.)

In my guide, in order to calculate the Hessian, I’m using the following code:

    hess_center = pyro.condition(model, data=transformed_hat_data)
    mytrace = poutine.block(poutine.trace(hess_center).get_trace)(data, scale, include_nuisance)
    log_posterior = mytrace.log_prob_sum()
    neg_big_hessian, big_grad = myhessian.arrowhead_hessian(log_posterior, theta_parts,
                len(theta_hat_data),  # number of tensors, not elements (= tlen)
                blocksize,
                return_grad=True)

(think of myhessian.arrowhead_hessian as just hessian.hessian)

When the model includes plate and/or scale statements, I get errors about log_prob shapes at the log_prob_sum() call (log_posterior = mytrace.log_prob_sum()). I think that’s because I’m not doing the conditioning correctly in the pyro.condition call. I’ve tried fiddling with this in various ways, and the error messages change, but I can’t get it to actually work. Any help? Thanks in advance!

Post the error?

The current error is:

    ValueError: Error while computing log_prob_sum at site 'y':
    Value is not broadcastable with batch_shape+event_shape: torch.Size([68, 3]) vs torch.Size([17, 4, 3]).
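As far as I can tell, this message comes from the sample validation in torch.distributions, which compares the value's shape with batch_shape + event_shape using the usual broadcasting rule. A stdlib-only sketch of that rule, with a hypothetical helper broadcastable (not a Pyro or PyTorch API), shows why (68, 3) fails against (17, 4, 3). Note that 68 = 17 × 4, which suggests the conditioned value for y holds the right numbers but has been flattened.

```python
def broadcastable(value_shape, expected_shape):
    """Return True if value_shape broadcasts against expected_shape.

    Sizes are compared right to left; each aligned pair must be equal,
    or one of the two sizes must be 1.
    """
    for v, e in zip(reversed(value_shape), reversed(expected_shape)):
        if v != e and v != 1 and e != 1:
            return False
    return True


expected = (17, 4, 3)  # batch_shape + event_shape from the error message
print(broadcastable((68, 3), expected))     # False: 68 vs 4 in the second-to-last dim
print(broadcastable((17, 4, 3), expected))  # True: the same 68 rows regrouped as 17 x 4
```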

I’ve also gotten “invalid log_prob shape” in the past.

The current code on the model side includes:

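    import contextlib
    import pyro

    all_ps_plate = pyro.plate('all_ps', P)
    @contextlib.contextmanager
    def all_ps():
        # Scaling is disabled for now; the version with it was:
        #   with all_ps_plate as p, poutine.scale(scale=scale) as pscale:
        with all_ps_plate as p:
            yield p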

and then later:

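    import pyro.distributions as dist

    with all_ps() as p_tensor:  # previously: pyro.plate('precinctsm2', P)
        # .to_event(1) moves the rightmost batch dim of logits into the event shape
        y = pyro.sample("y",
                        dist.Multinomial(1000, logits=logits).to_event(1))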

Any idea what’s going on? If I need to post more, what should that consist of?

If I can’t fix this legitimately, I’m considering diving into the Pyro source and just replacing the raise calls with almost-hard-coded tensor reshaping: the ugliest of hacks. I hope it doesn’t come to that.

In other words: “please help me now, or the dog gets it.” :sweat_smile: