Is it possible to use the new checkpointing feature from PyTorch in Pyro? I get an error that says:

```
RuntimeError: Checkpointing is not compatible with .grad(), please use .backward() if possible
```
The SVI code uses `.backward()` instead of `torch.autograd.grad` for computing gradients, so it should be compatible with checkpointing. Other algorithms, including the leapfrog integrator in HMC, will unfortunately not work because checkpointing does not support `.grad()`.
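To illustrate the distinction, here is a minimal sketch in plain PyTorch (no Pyro) showing that a checkpointed forward pass works fine with `.backward()`. The model and shapes are made up for the example; the `use_reentrant` keyword assumes a reasonably recent PyTorch version:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy module whose forward pass we checkpoint to save activation memory.
net = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))

x = torch.randn(4, 10, requires_grad=True)

# checkpoint() discards intermediate activations and recomputes them
# during the backward pass. The reentrant implementation is the one
# that raises the RuntimeError quoted above when used with .grad().
y = checkpoint(net, x, use_reentrant=True)
loss = y.sum()

# This path is supported: gradients flow back through the checkpoint.
loss.backward()
print(x.grad.shape)  # torch.Size([4, 10])
```

Replacing `loss.backward()` with `torch.autograd.grad(loss, x)` is what triggers the incompatibility for the reentrant checkpoint, which is why SVI (which calls `.backward()`) works while HMC's leapfrog integrator (which relies on `.grad`) does not.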
Why do I always try to do things that aren’t possible
@kkyang you should open a PyTorch issue about using `checkpoint` with `grad`; it seems like a generally useful enhancement.