The usage of PyroModule

Hi @enemis and @kek,

thanks for trying out the new PyroModule, and I apologize for the work involved in porting older code.

Re: purpose of PyroModule, our main motivation was to support serialization and jitting to make it easier to serve trained models. We’ve found that PyroModule allows us to use torch.jit.trace_module without needing to serialize the Pyro param store; the saved model can then be loaded in C++ with torch::jit::load.

Re: usage, PyroModule is still very fresh and we’re open to adding design patterns and new functionality. One constraint in using PyroModule is that submodules must also inherit from PyroModule if they are to be treated in a Bayesian fashion. (Actually the constraint is a little more complex: the root module must be a PyroModule, and every module along the attribute path from the root down to any PyroParam or PyroSample attribute must also be a PyroModule.)

@enemis, I had not thought about your .named_parameters() automation of random_module, so let me know if you find a good way to automate this in PyroModule; I will think about this as well. One design pattern I have used is to define the layer types as class attributes of the network I aim to Bayesianize, so that a derived class can swap them out:

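import torch
import torch.nn as nn
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class Network(nn.Module):
    Linear = nn.Linear  # this can be overridden by derived classes
    def __init__(self, in_features=2, out_features=1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = self.Linear(self.in_features, self.out_features)

    def forward(self, x):
        return self.linear(x)

class RandomNetwork(Network, PyroModule):
    Linear = PyroModule[nn.Linear]
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # The prior function is called with the module that owns the
        # attribute (here self.linear, an nn.Linear), so .out_features and
        # .in_features below are the layer's own attributes.
        self.linear.weight = PyroSample(
            lambda self: dist.Normal(0, 1)
                             .expand([self.out_features,
                                      self.in_features])
                             .to_event(2))

rand_net = RandomNetwork(2, 1)
y = rand_net(torch.randn(5, 2))  # each call draws a fresh weight from the prior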

Again, let us know if you find good automation patterns :smile:
-Fritz