I’ve been experimenting with Pyro for only a few days, but I wanted to share some personal considerations before going on holiday.
– Why is it necessary to manually specify independence with pyro.iarange() and friends? This could be done automatically, either via a tree-like structure or by subclassing Variable and propagating dependencies, so that pyro.sample() statements have the information they need to establish the set of parents of the random variable in question. I’m probably missing something here: are there performance issues with this kind of approach?
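To make the idea concrete, here is a toy sketch in plain Python of the “propagate dependencies” approach. Everything in it (the `Tracked` wrapper, the `sample` function, the `parents` dict) is hypothetical and is not how Pyro or PyTorch work; it only shows a wrapper that carries the set of sample sites a value was computed from, so that each sample statement can record its parents.

```python
class Tracked:
    """Hypothetical wrapper (not a Pyro or PyTorch class): a value plus the
    set of sample sites it was computed from."""

    def __init__(self, value, deps=()):
        self.value = value
        self.deps = frozenset(deps)

    def _combine(self, other, op):
        # The result depends on everything either operand depends on;
        # plain numbers contribute no dependencies.
        if isinstance(other, Tracked):
            return Tracked(op(self.value, other.value), self.deps | other.deps)
        return Tracked(op(self.value, other), self.deps)

    def __add__(self, other):
        return self._combine(other, lambda a, b: a + b)

    def __mul__(self, other):
        return self._combine(other, lambda a, b: a * b)

    # Addition and multiplication commute, so the reflected forms reuse them.
    __radd__ = __add__
    __rmul__ = __mul__


parents = {}  # hypothetical graph: site name -> frozenset of parent site names


def sample(name, value, *params):
    """Hypothetical sample statement: records the parents of site `name` as the
    union of the dependencies of its distribution parameters. A real one would
    draw `value` from a distribution; here it is passed in to keep things simple."""
    parents[name] = frozenset().union(
        *(p.deps for p in params if isinstance(p, Tracked))
    )
    return Tracked(value, {name})


mu = sample("mu", 0.5)
sigma = sample("sigma", 2.0)
loc = mu * 2 + 1                    # depends on "mu" only
x0 = sample("x0", 1.3, loc, sigma)
x1 = sample("x1", 0.7, loc, sigma)

print(parents["x0"])                # contains both "mu" and "sigma"
print("x0" in parents["x1"])        # False: x1 does not depend on x0
```

Even this toy hints at where it could get hard: it tracks whole values, so it cannot tell that the elements of one batched tensor are independent of each other, and anything that leaves the wrapper (control flow on `.value`, indexing, in-place updates) silently drops dependencies. I would guess that is part of why the declaration is explicit.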
– It’s possible to write code that raises no errors or warnings but results in incorrect inference. For instance, at some point I refactored the construction of a specific kind of guide distribution into a function. That code contained calls to pyro.param(), and I ended up calling the function once at initialization time (of the object exposing model() and guide()) and storing the returned guide distribution as a data member. This produced no error or warning, but the ELBO wasn’t optimized correctly. These kinds of issues are possible in PyTorch too (with respect to correct backpropagation), but they seem to be more of a problem in Pyro, probably because parameters are identified by strings. I’ve been following the GitHub issue on this topic, and I’m not 100% convinced either that the current approach is robust.
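As far as I understand, SVI decides which parameters to update by recording the pyro.param() calls made while it runs model and guide during a step, so a guide whose pyro.param() calls only happened at initialization contributes nothing to optimize. Below is a toy, stdlib-only sketch of that failure mode. `ToyParamStore`, `toy_step`, `make_guide_params`, `Fresh` and `Cached` are all hypothetical names, the “gradient step” is just a pull toward a fixed target, and none of this is Pyro’s implementation.

```python
class ToyParamStore:
    """Hypothetical name-keyed store (NOT Pyro's implementation)."""

    def __init__(self):
        self.values = {}
        self.accessed = set()  # names looked up during the current step

    def param(self, name, init):
        # Get-or-create: `init` is only used the first time `name` is seen.
        self.values.setdefault(name, init)
        self.accessed.add(name)
        return self.values[name]


def toy_step(store, guide, target=3.0, lr=0.5):
    """Hypothetical optimizer step: only parameters read by `guide` during
    this step are treated as trainable, mimicking trace-based collection."""
    store.accessed.clear()
    guide()
    for name in store.accessed:
        # Stand-in for a gradient step pulling the parameter toward `target`.
        store.values[name] += lr * (target - store.values[name])


def make_guide_params(store):
    # Hypothetical refactored helper that registers parameters by name.
    return {"loc": store.param("loc", 0.0)}


class Fresh:
    def __init__(self, store):
        self.store = store

    def guide(self):
        return make_guide_params(self.store)  # looks parameters up on every call


class Cached:
    def __init__(self, store):
        self.store = store
        self.params = make_guide_params(store)  # runs once, at init time

    def guide(self):
        return self.params  # never touches the store during a step


fresh_store, cached_store = ToyParamStore(), ToyParamStore()
fresh, cached = Fresh(fresh_store), Cached(cached_store)
for _ in range(10):
    toy_step(fresh_store, fresh.guide)
    toy_step(cached_store, cached.guide)

print(fresh_store.values["loc"])   # ≈ 2.997, moving toward the target
print(cached_store.values["loc"])  # 0.0, never updated
```

Neither version raises anything, which matches the silent failure described above.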
– More specifically, which parts of the model and the guide must be created inside the model() and guide() functions passed to SVI(), and which parts can safely be cached outside? That isn’t immediately obvious. For instance, the examples in the tutorials usually create PyTorch Variables (with requires_grad set to True) on every call of guide(), whereas one might assume they only need to be created once, which is what you would do when optimizing a loss in plain PyTorch.
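Part of the answer, as far as I can tell, is that pyro.param() behaves as a get-or-create lookup by name: the initial value is only used the first time a name is seen, and later calls return whatever is stored, so rebuilding the initial Variable on every call is wasted work but not wrong. A toy stdlib sketch of that behaviour (the `param_store` dict and `param` function are hypothetical stand-ins, not Pyro code):

```python
param_store = {}  # hypothetical stand-in for a global, name-keyed store


def param(name, init_value):
    """Toy get-or-create lookup (NOT Pyro's code): `init_value` is only
    stored the first time `name` is seen; later calls ignore it."""
    if name not in param_store:
        param_store[name] = init_value
    return param_store[name]


def guide():
    # Like the tutorials: a fresh initial value is built on every call...
    return param("loc", 0.0)


first = guide()            # 0.0: the first call registers the initial value
param_store["loc"] = 2.5   # pretend an optimizer step updated "loc"
second = guide()           # 2.5: the freshly built 0.0 is ignored
```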
– I would welcome integration with PyTorch’s DataLoader, which is the standard batching mechanism in PyTorch. I couldn’t find information on this in the tutorials, and it’s not clear to me whether it’s supported.
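Whatever supplies the batches, the key requirement for SVI on mini-batches is that the likelihood of a batch of size B drawn from N data points is scaled by N/B, so that its expectation matches the full-data term. If I read the docs correctly, this is what the size/subsample arguments of pyro.iarange() are for, and the `subsample` argument might be a way to feed in indices from a DataLoader, but I haven’t verified that. The plain-Python sketch below only shows the arithmetic; `log_lik` is a hypothetical unit-variance Gaussian term and the data are made up.

```python
import random


def log_lik(x, theta):
    # Hypothetical per-datum log-likelihood (unit-variance Gaussian, up to a constant).
    return -0.5 * (x - theta) ** 2


def scaled_batch_log_lik(batch, theta, n_total):
    # Rescale the batch sum by N / B so that, over random batches,
    # its expectation equals the full-data sum.
    return (n_total / len(batch)) * sum(log_lik(x, theta) for x in batch)


data = [0.1 * i for i in range(12)]
theta = 0.4
full = sum(log_lik(x, theta) for x in data)

# Partition a shuffled copy into 4 batches of 3, as a DataLoader with
# shuffle=True and batch_size=3 would for one epoch.
shuffled = data[:]
random.Random(0).shuffle(shuffled)
batches = [shuffled[i:i + 3] for i in range(0, len(shuffled), 3)]
estimates = [scaled_batch_log_lik(b, theta, len(data)) for b in batches]

# Averaged over one epoch of disjoint batches, the scaled estimates
# recover the full-data log-likelihood exactly (up to rounding).
avg = sum(estimates) / len(estimates)
print(avg, full)
```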
– I would welcome “debug” info about the inference algorithm used when calling SVI.step(). It’s nice to have a framework that automatically takes care of various parts of inference, but it’s not always evident what is happening under the hood. For example, for Gamma variables in the guide it’s still possible to rely on a reparametrization trick (via an approximate inverse CDF or the generalized reparametrization trick), but the user doesn’t know whether this is the case or not.
Thank you for open-sourcing your work, I’m looking forward to future developments.