I have a custom backward-pass operator that computes parameter gradients through a time series using the adjoint method. This code works fine for deterministic optimization in plain PyTorch.
With Pyro, at least for the moment, I'm converting each parameter in the original model to a distribution and using an autoguide.
One thing the adjoint operator needs to do is determine which parameters to propagate gradients for in the backward pass. In plain PyTorch it's sufficient to use `model.parameters()` to generate this list.
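In the deterministic setting, the pattern is roughly this (a minimal sketch with a stand-in linear dynamics function, not my actual operator):

```python
import torch
import torch.nn as nn

# Stand-in for the dynamics function whose parameters the adjoint pass needs
f = nn.Linear(3, 3)

# model.parameters() is enough to enumerate what to propagate gradients for
params = [p for p in f.parameters() if p.requires_grad]

# Inside the backward pass: a local vector-Jacobian product at a fixed state,
# with the adjoint variable supplied as the output cotangent
state = torch.randn(3)
adjoint = torch.randn(3)
out = f(state)
param_grads = torch.autograd.grad(out, params, grad_outputs=adjoint)
```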
After converting the model to Pyro, `model.parameters()` (correctly) returns nothing, since the parameters are all associated with the guide.
1. Is there a way to determine which parameters are associated with a model via introspection?
2. Will I need access to the guide to do this?
3. In general, is there a way to do AD over a Pyro model with respect to guide parameters?
You need #3 even for the adjoint method, because the adjoint pass needs "local" parameter gradients at a fixed state, and it's convenient to get those via AD.