The usage of PyroModule

I'm rewriting some Pyro code and saw that my beloved random_module is being phased out. Now I'm trying to understand PyroModule.

The way I've usually worked is to create a standard PyTorch nn.Module and then create the two functions model() and guide() to define priors, likelihood and posteriors. With some help from net.named_parameters() it's quite simple to create a mean-field model. See an example e.g. here. A useful property of this, combined with random_module, is that I can get a posterior sample of my model by simply calling the guide function.

So a general question is how one should build a full system using PyroModule instead of random_module.

Specifically, I'm trying to modify a full network instead of a single torch layer. However, this seems to break down. I suspect I just don't understand the purpose of PyroModule yet, but what is the best practice here?

TypeError: cannot assign 'pyro.nn.module.PyroSample' as parameter 'weight' (torch.nn.Parameter or None expected)

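from torch import nn
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class Network(nn.Module):
    def __init__(self, in_features, out_features):
        super(Network, self).__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, y=None):
        mean = self.linear(x)
        return mean

class RandomNetwork(Network, PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # Raises the TypeError above: self.linear is a plain nn.Linear,
        # which only accepts an nn.Parameter (or None) for 'weight'
        self.linear.weight = PyroSample(
            lambda self: dist.Normal(0, 1)
                             .expand([self.out_features, self.in_features])
                             .to_event(2))

    def model(self, x, y=None):
        ...  # body not included in the original post

rand_net = RandomNetwork(2, 1)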

The usage of PyroModule in combination with nn.Module is explained here.
Maybe it helps.

Hi Kek,
I've read this but didn't really get it. It is also about an nn.Linear, not a generic neural net with parameters… My example above is an nn.Linear inside a network…

I checked your code and changed nn.Linear to PyroModule[nn.Linear]. I also adapted the class inheritance of Network so that PyroModule is able to handle the linear layer as a Pyro module.
Now it should be possible to assign the network's weight parameters with Pyro.
Hope this fixed your problem.

from torch import nn
from pyro.nn import PyroSample, PyroModule
import pyro.distributions as dist
import torch

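# Network mixes nn.Linear with PyroModule, and the inner layer is a PyroModule[nn.Linear],
# so every module on the path down to self.linear.weight is a PyroModule.
class Network(nn.Linear, PyroModule):
    def __init__(self, in_features, out_features):
        super(Network, self).__init__(in_features, out_features)
        self.linear = PyroModule[nn.Linear](self.in_features, self.out_features)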

Hi @enemis and @kek,

thanks for trying out the new PyroModule and I apologize for the work involved in porting older code.

Re: purpose of PyroModule, our main motivation was to support serialization and jitting to make it easier to serve trained models. We’ve found that PyroModule allows us to use torch.jit.trace_module without needing to serialize the Pyro param store; the saved model can then be loaded in C++ with torch::jit::load.

Re: usage, PyroModule is still very fresh and we’re open to adding design patterns and new functionality. One constraint in using PyroModule is that submodules must also inherit from PyroModule if they are to be treated in a Bayesian fashion. (Actually the constraint is a little more complex: the root module must be a PyroModule, and every module on the attribute path from the root module down to any Pyro*() attribute must be a subclass of PyroModule.) @enemis I had not thought about your .named_parameters() automation of random_module, so let me know if you find a good way to automate this in PyroModule; I will think about this as well. One design pattern I have used is to define types in the network I aim to Bayesianize:

class Network(nn.Module):
    Linear = nn.Linear  # this can be overridden by derived classes
    def __init__(self, in_features=2, out_features=1):
        super().__init__()  # required before assigning submodules
        self.in_features = in_features
        self.out_features = out_features
        self.linear = self.Linear(self.in_features, self.out_features)

class RandomNetwork(Network, PyroModule):
    Linear = PyroModule[nn.Linear]
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.linear.weight = PyroSample(
            lambda self: dist.Normal(0, 1)
                             .expand([self.out_features, self.in_features])
                             .to_event(2))

rand_net = RandomNetwork(2, 1)

Again, let us know if you find good automation patterns :smile:

Hi @fritzo,
thank you for your reply. I prefer progress over legacy, so please go ahead with the upgrade, great job! (I did find it a bit amusing to deprecate a function in favour of an experimental one :wink: )

Over the last months, I have made quite a few hacks to productionize some Pyro models. Using a separate nn.Module and .random_module() worked quite nicely to jit.trace a (single) posterior sample. I then tried to figure out how to easily sample during jit, and I believe you have some good ideas in PyroModule (and I can throw some of mine in the bin)!

To fix my issues in this thread, I've made a helper function to add simple noninformative priors to my modules (e.g. I rewrote the example above). I believe the solution is very similar to your proposed PR (great job):

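def set_noninform_prior(mod, scale=1.0):
    # Replace each submodule's own parameters, so nested names like "linear.weight" land on the owning (Pyro)module
    for m in mod.modules():
        for key, par in list(m.named_parameters(recurse=False)):
            prior = dist.Normal(torch.zeros_like(par), scale * torch.ones_like(par)).to_event(par.dim())
            setattr(m, key, PyroSample(prior))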

While working on PyroModule I ran into a couple of other questions; I would love feedback (if you have time):

  • Although I haven't done any speed checks yet, does it hurt speed-wise to serve a model through the model() function? In reality, I do not want to create a distribution for the data, i.e. the whole with pyro.plate("data"): ... section should not be needed. Or will this be dropped when tracing it?
  • Related to the point above: I've currently implemented a simple predictive function (for one sample). Would this be the recommended way to do it? (Like above, the full example is in the updated file here):
    def predict(self, batch, guide_trace = None):
        if guide_trace is None:
            guide_trace = self.sample_guide_trace(batch)

        model_trace = poutine.trace(poutine.replay(self.model, guide_trace)).get_trace(batch)
        return model_trace.nodes['_RETURN']['value']

Hi @enemis, nice! I see your set_noninform_prior() is very similar to the example in to_pyro_module_().

You should be able to simplify predict() method by omitting the model tracing:

def predict(self, batch, guide_trace=None):
    if guide_trace is None:
        guide_trace = self.sample_guide_trace(batch)
    with poutine.replay(trace=guide_trace):
        return self.model(batch)

Regarding speed, I can’t make specific recommendations without seeing your model code. I generally add a flag like forecast=False to my model to switch between training mode and prediction mode, where a big likelihood term is gated by the flag. Sometimes you can even get away with omitting the model entirely, e.g. if you only need to predict latent variables, you can call the guide without needing to replay the model. If you write your guide and model carefully, you can make predictions without using any Pyro code, i.e. use Pyro only for training.


I’ll jump into this thread as I have some of the same questions. I think I am confused about when to use PyroModule and when not to (or when it is necessary).

It seems that in my VAE, the encoder/decoder networks shouldn’t be PyroModules because if they are then the neural networks don’t learn any weights. But my regression module should, I think, because it was using random_module before?

Apologies if I’m asking silly questions; I’ve just been spending a long time in standard PyTorch and my brain might have forgotten how all of this works…

Hi @jtw, as far as I understand (correct me if I'm wrong, @fritzo), you can use PyroModule in both of these cases. However, you need to be careful where you place your PyroParam() and PyroSample().

Basically, if you have a Bayesian parameter (i.e. a model parameter), you need to use PyroSample() to make it a latent variable of the model. I guess this is what you want to use in your regression.

If you have variational parameters (e.g. the encoder and decoder in a VAE), you should use PyroParam because these are variational parameters. The latent state of the VAE is, however, still a model parameter, and you need to call pyro.sample() on it in the model().

Hi @jtw (and @enemis and @kek), thanks for trying out PyroModule! First a meta question:

What kind of documentation would you find most useful? Maybe a PyroModule tutorial, or more content in the docstring, or maybe a "10 things to remember about PyroModule" listicle, or something else?

Re: when to use PyroModule and when not to

Use PyroModule to add Pyro effects to an nn.Module. The two Pyro effects are (1) pyro.param statements (PyroModule = nn.Module + pyro.module() statements), and (2) pyro.sample statements, when you want to be Bayesian about a module. I have started using PyroModule for each of model and guide so I don’t have to write any pyro.module() or pyro.param() statements and so I can use constrained parameters. Sometimes I wrap that model and guide in a single nn.Module so I can torch.jit.trace_module the whole thing (see pyro.contrib.cevae.CEVAE for example).

Also an update on @enemis’s original question: as of Pyro 1.1 you can now use to_pyro_module_():

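from torch import nn
import pyro.distributions as dist
from pyro.nn import PyroSample
from pyro.nn.module import to_pyro_module_

class Network(nn.Module):  # <--- not a PyroModule
    # as defined above (with a plain nn.Linear), repeated so the snippet runs on its own
    def __init__(self, in_features=2, out_features=1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)

rand_net = Network(2, 1)

# Now we can convert rand_net to a PyroModule in-place
to_pyro_module_(rand_net)  # <--- now rand_net is a PyroModule
rand_net.linear.weight = PyroSample(
    lambda self: dist.Normal(0, 1)
                     .expand([self.out_features, self.in_features])
                     .to_event(2))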

Again, if you have other ideas for design patterns, let us know!

The CEVAE example is useful as it’s closer to the architecture of my existing code, which looks more like the original VAE tutorial.

I think the main source of confusion for me was that I thought using PyroModule meant that I’d need a prior for every variable in that module (e.g. all the weights in my encoder/decoder networks), but really it’s just a signal that some of the variables will have priors/be sampled during training (e.g. the latent space).

In terms of useful documentation, I think just converting the existing examples would be a good start. Maybe showing before/after, so that people like me can see how to convert code. And a short summary of the rules (I guess this is the listicle idea…).

  • The submodule rules you mentioned above: how far you need to percolate PyroModule down (or up) to have it work as intended.
  • What PyroModule means for the variables in the module (with or without PyroSample). It seems like the answer is straightforward; it just wasn’t obvious to me.

Looking at the “Mixins” section of Neural Networks, I think my confusion was coming from a combination of those two points, as illustrated by this code snippet:

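from torch import nn
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule, PyroSample

model = PyroModule[nn.Sequential](
    PyroModule[nn.Linear](28 * 28, 100),
    PyroModule[nn.Linear](100, 100),
    PyroModule[nn.Linear](100, 10),
)
assert isinstance(model, nn.Sequential)
assert isinstance(model, PyroModule)

# Now we can be Bayesian about weights in the first layer.
# nn.Linear(28 * 28, 100).weight has shape [100, 28 * 28] (out_features, in_features).
model[0].weight = PyroSample(
    prior=dist.Normal(0, 1).expand([100, 28 * 28]).to_event(2))
guide = AutoDiagonalNormal(model)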

In this example, every nn.Module is being wrapped, but only the first linear layer is getting a prior. I didn’t really understand that until now, so I didn’t get what was happening. As I understand it now, most of those mixin statements are basically no-ops, except that they allow the module as a whole to function?


Re: "What kind of documentation would you find most useful? Maybe a PyroModule tutorial"

Tutorials and examples! :slight_smile: I've found that the current examples include a lot of extra material that is not needed to understand the basics. For example, in the EasyGuide example I think I get confused because there is time series modeling mixed in :slight_smile: It's great to get more complexity over time, but in the beginning it would be great to understand how they work on standard regression problems etc.

For PyroModule specifically, it would be great to understand a couple of “best practices” on how you intended them to be used. For example, I did not think of using it for defining the guide :slight_smile: Wrapping model+guide in an nn.Module is also a very interesting use case!

Realizing I'm just listing my issues, I would be happy to contribute an example or two when I get up to speed on how the framework should be interpreted.