Checkpointing in Pyro?

Is it possible to use the new checkpointing feature from PyTorch in Pyro? I get an error that says:

RuntimeError: Checkpointing is not compatible with .grad(), please use .backward() if possible

The SVI code uses .backward() rather than torch.autograd.grad to compute gradients, so it should be compatible with checkpointing. Other algorithms, including the leapfrog integrator in HMC, will unfortunately not work, because checkpointing does not support .grad().
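Here is a minimal sketch of that distinction. It assumes a recent PyTorch (2.x) whose torch.utils.checkpoint.checkpoint accepts the use_reentrant argument. The function block and the grad_via_* helpers are hypothetical stand-ins, not Pyro code. The sketch compares calling .backward() through a checkpointed block (the SVI pattern) with calling torch.autograd.grad (the pattern HMC's leapfrog integrator relies on). With the reentrant implementation, the second call raises a RuntimeError. The exact message text varies between PyTorch versions.

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(x):
    # Hypothetical stand-in for an expensive sub-network; with checkpointing its
    # intermediate activations are recomputed during the backward pass.
    return torch.tanh(x @ x.T).sum(dim=1)


def grad_via_backward(x, use_reentrant):
    # SVI-style: build a scalar loss, then call .backward().
    x = x.detach().requires_grad_(True)
    loss = checkpoint(block, x, use_reentrant=use_reentrant).sum()
    loss.backward()
    return x.grad


def grad_via_autograd_grad(x, use_reentrant):
    # HMC-style: ask autograd for the gradient directly with torch.autograd.grad.
    x = x.detach().requires_grad_(True)
    loss = checkpoint(block, x, use_reentrant=use_reentrant).sum()
    (g,) = torch.autograd.grad(loss, x)
    return g


def reference_grad(x):
    # Same computation without checkpointing, for comparison.
    x = x.detach().requires_grad_(True)
    block(x).sum().backward()
    return x.grad


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 3)

    # .backward() through the reentrant checkpoint matches the plain gradient.
    print(torch.allclose(grad_via_backward(x, True), reference_grad(x)))

    # torch.autograd.grad through the reentrant checkpoint raises RuntimeError.
    try:
        grad_via_autograd_grad(x, True)
    except RuntimeError as err:
        print("RuntimeError:", err)

    # Assumption: recent PyTorch, where the non-reentrant variant supports .grad().
    print(torch.allclose(grad_via_autograd_grad(x, False), reference_grad(x)))
```

The last case reflects the PyTorch documentation for recent releases, which describes the non-reentrant implementation (use_reentrant=False) as supporting torch.autograd.grad. That option did not exist in the early versions of checkpointing that produced the error above.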

Why do I always try to do things that aren’t possible? 🙁

@kkyang you should open a PyTorch issue about supporting torch.autograd.grad with checkpointing; it seems like a generally useful enhancement.
