Thanks for trying out the new `PyroModule`, and I apologize for the work involved in porting older code.

Re: purpose of `PyroModule`, our main motivation was to support serialization and jitting to make it easier to serve trained models. We've found that `PyroModule` allows us to use `torch.jit.trace_module` without needing to serialize the Pyro param store; the saved model can then be loaded in C++ with `torch::jit::load`.
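For concreteness, here is a minimal sketch of that workflow. It uses a plain `nn.Module` as a stand-in for a trained model, and the module/file names are just placeholders:

```python
import torch
import torch.nn as nn

# A stand-in for a trained model; in practice this would be your
# (Pyro-trained) module. Its parameters live on the module itself,
# so there is no separate Pyro param store to serialize.
net = nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 1))
example = torch.randn(3, 2)

# Record the forward pass into a TorchScript module.
traced = torch.jit.trace_module(net, {"forward": (example,)})

# Save to disk; the archive can then be loaded in C++ via
# torch::jit::load("net.pt").
traced.save("net.pt")
```

The saved `.pt` archive bundles both the traced graph and the parameters, which is why no separate param-store file is needed on the serving side.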
Re: usage, `PyroModule` is still very fresh, and we're open to adding design patterns and new functionality. One constraint in using `PyroModule` is that submodules must also inherit from `PyroModule` if they are to be treated in a Bayesian fashion (actually the constraint is a little more complex: the root module must be a `PyroModule`, and the attr path from the root module down to any `Pyro*()` attribute must consist entirely of `PyroModule` subclasses).

@enemis I had not thought about your `.named_parameters()` automation of `random_module`, so let me know if you find a good way to automate this in `PyroModule`; I will think about this as well.

One design pattern I have used is to define types in the network I aim to Bayesianize:
```python
import torch.nn as nn
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class Network(nn.Module):
    Linear = nn.Linear  # this can be overridden by derived classes

    def __init__(self, in_features=2, out_features=1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = self.Linear(self.in_features, self.out_features)

class RandomNetwork(Network, PyroModule):
    Linear = PyroModule[nn.Linear]

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # The prior fn is called with the submodule (self.linear) as its
        # argument, so in_features/out_features resolve on nn.Linear.
        self.linear.weight = PyroSample(
            lambda self: dist.Normal(0., 1.)
                             .expand([self.out_features,
                                      self.in_features])
                             .to_event(2))

rand_net = RandomNetwork(2, 1)
```
Again, let us know if you find good automation patterns!

-Fritz