Using constraints within an nn.Module

Dear Pyro team,

I’m in the process of packaging the implementation of my model as an nn.Module, but I'm unsure how to manage constraints in this context; or, to put it another way, whether there’s an equivalent to pyro.param("name", value, constraint=constraint) that I can use within a module.

For example, if my module has a covariance matrix, I would naturally like to use constraint=constraints.lower_cholesky. However, the tutorials (VAE, Deep Markov Model, Bayesian Regression) use e.g. nn.Softplus() to enforce positivity of their scale parameters rather than the constraints library (admittedly, in some of those cases this is because the scales are outputs of an amortised network). For more complex constraints such as lower_cholesky, there is no equally simple elementwise trick.

As I understand it I have two options:

  • Keep the unconstrained parameters as nn.Parameter attributes of the module and track the transformations manually (or by reusing a few key bits of logic from param_store.py).

  • Simply prepend an instance-specific string to the parameter names and keep relying on the param store to manage them and their constraints (i.e. pyro.param(unique_id + "name", value, constraint=constraint)). That avoids name clashes between multiple instances of the part of the model wrapped up as a module.

The former sticks much more closely in spirit to the aim of repurposing the code as an nn.Module, in that parameters would behave as expected when saving, registering with Pyro, calling .cuda(), etc. However, it seems like there might be a better way to use the existing constraint management framework in Pyro. Does that make sense, and have I missed an obvious way to avoid reinventing the wheel in terms of keeping track of parameter / constraint pairs?
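As a rough illustration of the first option, here is a minimal sketch in plain PyTorch. The class name ManualConstraintsModule, the constrained() helper, the _constraints dict and the "_unconstrained" suffix are all hypothetical. The parameter is initialised from a constrained value (the identity Cholesky factor) by inverting the transform once, which, as far as I understand, mirrors what the param store does when pyro.param first stores a constrained value. Only an ordinary nn.Parameter is registered, so state_dict() and .cuda() treat it like any other parameter.

```python
import torch
from torch import nn
from torch.distributions import constraints, transform_to


class ManualConstraintsModule(nn.Module):
    """Hypothetical sketch of option 1: nn.Parameters hold unconstrained
    values and the module itself remembers which constraint goes with which."""

    def __init__(self, dim):
        super().__init__()
        # name -> constraint; a plain dict, so it is not part of state_dict()
        self._constraints = {"scale_tril": constraints.lower_cholesky}
        # Start from a *constrained* value (identity Cholesky factor) and
        # invert the transform once to get the unconstrained starting point
        init = torch.eye(dim)
        self.scale_tril_unconstrained = nn.Parameter(
            transform_to(constraints.lower_cholesky).inv(init)
        )

    def constrained(self, name):
        # Look up the unconstrained parameter and apply its stored constraint
        raw = getattr(self, name + "_unconstrained")
        return transform_to(self._constraints[name])(raw)


module = ManualConstraintsModule(3)
scale_tril = module.constrained("scale_tril")  # lower triangular, positive diagonal
```

The bookkeeping is small, but it is exactly the part that has to be repeated for every constrained parameter.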

Many thanks!!

The constraint transforms are purely functional, so you should be able to use them in an nn.Module via transform_to(), which is what Pyro does under the hood:

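import torch
from torch import nn
from torch.distributions import constraints, transform_to

class MyModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Store the raw values in unconstrained space as an ordinary nn.Parameter
        self.unconstrained_foo = nn.Parameter(torch.zeros(dim, dim))

    def forward(self, x):
        # Map to a lower-triangular matrix with a positive diagonal on every call
        foo = transform_to(constraints.lower_cholesky)(self.unconstrained_foo)
        ...  # use foo, e.g. as the scale_tril of a MultivariateNormal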
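To make the "purely functional" point concrete, here is a small check (the tensor names are only illustrative). The output of the lower_cholesky transform is lower triangular with a strictly positive diagonal, it can be passed straight to MultivariateNormal as scale_tril, and gradients flow back to the unconstrained tensor, so an optimiser only ever has to update unconstrained values.

```python
import torch
from torch.distributions import MultivariateNormal, constraints, transform_to

torch.manual_seed(0)
raw = torch.randn(3, 3, requires_grad=True)  # arbitrary unconstrained values
tril = transform_to(constraints.lower_cholesky)(raw)

# Structural properties guaranteed by the transform
is_lower = torch.equal(tril, tril.tril())
diag_positive = bool((tril.diagonal() > 0).all())

# Usable directly as a Cholesky factor; backprop reaches the raw tensor
mvn = MultivariateNormal(torch.zeros(3), scale_tril=tril)
mvn.log_prob(torch.ones(3)).backward()
has_grad = raw.grad is not None
```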

Perfect, many thanks for the response! Yes, that solves the problem of using the constraints.

The other side of things is keeping track of what needs to be transformed. What’s nice about pyro.param is that the constraint is specified once and then stored, whereas in an nn.Module you have to remember to apply the constraint every time you access self.unconstrained_foo. What I was sanity-checking was whether I had missed some Pyro magic that makes this (admittedly not particularly onerous ;-p) step easier, e.g. by overriding __get__ in an nn.Module to apply constraints, analogous to how the param store overrides __getitem__ and friends.

Thanks again 🙂
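For the automatic behaviour asked about above, one possibility in plain PyTorch is to override nn.Module.__getattr__ (rather than __get__); nn.Module already uses __getattr__ to look up entries in its internal _parameters dict, and it is only called when normal attribute lookup fails. The sketch below is hypothetical: the class name, register_constrained() and the "_unconstrained" suffix are my own choices, not a Pyro or PyTorch API. Reading self.<name> returns the constrained value, while only the unconstrained tensor is registered as a parameter, so state_dict(), .cuda() and optimisers all work on the unconstrained values.

```python
import torch
from torch import nn
from torch.distributions import constraints, transform_to


class ConstrainedParamModule(nn.Module):
    """Hypothetical helper: `self.<name>` returns the constrained value of a
    parameter stored as `<name>_unconstrained`."""

    def __init__(self):
        super().__init__()
        # name -> constraint, kept in the instance __dict__ (not a parameter)
        self._constraints = {}

    def register_constrained(self, name, init_value, constraint):
        # Store the unconstrained equivalent of the given constrained value
        raw = transform_to(constraint).inv(init_value).detach().clone()
        self.register_parameter(name + "_unconstrained", nn.Parameter(raw))
        self._constraints[name] = constraint

    def __getattr__(self, name):
        # Only reached when normal lookup fails, e.g. for parameters, which
        # live in self._parameters rather than in the instance __dict__.
        # Use __dict__ directly to avoid recursing into __getattr__.
        constraint_map = self.__dict__.get("_constraints", {})
        if name in constraint_map:
            raw = super().__getattr__(name + "_unconstrained")
            return transform_to(constraint_map[name])(raw)
        return super().__getattr__(name)


module = ConstrainedParamModule()
module.register_constrained("scale", torch.tensor([1.0, 2.0]), constraints.positive)
module.register_constrained(
    "cov_tril", torch.tensor([[2.0, 0.0], [0.5, 1.0]]), constraints.lower_cholesky
)
scale = module.scale  # constrained (positive) value, recomputed on each access
```

Each attribute access recomputes the transform, so inside forward() it may be worth reading the value once into a local variable.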