Simple Normalizing Flow Question

Hi, I was looking through the Transform/Normalizing Flow documentation and noticed that we’re supposed to register our transforms with Pyro, for example with:

pyro.module("my_transform", my_transform)

Then I saw there are helper functions (TransformFactories) that make it easier to create some of these transforms with fewer arguments (for example, without requiring us to explicitly create a separate DenseNN). But there’s nothing in the documentation about registering these transforms with Pyro, and the Deep Markov Model tutorial doesn’t seem to register them with Pyro either. When I looked at the source code of these helper functions, I noticed they don’t register the transforms in the param store either.

So I experimented with a simple model and guide that use normalizing flows (below), and they seem to perform the same whether or not I register the normalizing flow/transform parameters with Pyro. So I’m a little confused: is Pyro just keeping the transforms as PyTorch parameters and optimizing both the PyTorch and Pyro parameters at the same time? If not, why would the model perform the same whether or not I register the normalizing flow parameters in the Pyro param store?

I pasted the model, guide, and training code below.

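import torch
import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim
from pyro.distributions import TransformedDistribution, constraints
from pyro.distributions.transforms import neural_autoregressive

def flow_regression(is_cont_africa, ruggedness, log_gdp=None):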
    # Priors.
    intercept = pyro.sample('intercept', dist.Normal(0., 10.))
    betas = pyro.sample('betas', dist.Normal(0., 1.).expand([3]).to_event(1))
    sigma = pyro.sample('sigma', dist.Uniform(0., 10.))
    # Linear model.
    mean = intercept + betas[0] * is_cont_africa + betas[1] * ruggedness + betas[2] * is_cont_africa * ruggedness
    # Likelihood.
    with pyro.plate('data', len(ruggedness)):
        return pyro.sample('obs', dist.Normal(mean, sigma), obs=log_gdp)

def flow_guide(is_cont_africa, ruggedness, log_gdp=None):
    # Normalizing flows.
    flows = neural_autoregressive(input_dim=3)
    pyro.module('flows', flows)  # <<<--- Do I Need This???
    # Variational parameters.
    intercept_loc = pyro.param('intercept_loc', lambda: torch.tensor(5.))
    intercept_scale = pyro.param('intercept_scale', lambda: torch.tensor(0.5), constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', lambda: torch.randn(3))
    weights_scale = pyro.param('weights_scale', lambda: torch.ones(3), constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(1.), constraint=constraints.positive)
    # Variational distributions.
    intercept = pyro.sample('intercept', dist.Normal(intercept_loc, intercept_scale))
    betas = pyro.sample('betas', TransformedDistribution(dist.Normal(weights_loc, weights_scale), flows))
    sigma = pyro.sample('sigma', dist.Normal(sigma_loc, torch.tensor(0.05)))
    return {'intercept': intercept, 'betas': betas, 'sigma': sigma}

clip_adam = pyro.optim.ClippedAdam({'lr': 0.01})
elbo = pyro.infer.Trace_ELBO(num_particles=5)
svi = pyro.infer.SVI(flow_regression, flow_guide, clip_adam, elbo)

losses = []
for step in range(1001):
    loss = svi.step(is_cont_africa, ruggedness, log_gdp)
    losses.append(loss)  # keep the loss history for later inspection
    if step % 100 == 0:
        print(f"Step {step} - ELBO loss: {loss}")

Yes, regardless of how you create the flow, it needs to be registered (it’s a module). If you don’t register it, its parameters should not be updated.

Ok, that makes sense. I was thinking there must be a mistake in the Deep Markov Model tutorial then because I noticed its code below

self.iafs = [affine_autoregressive(z_dim, hidden_dims=[iaf_dim]) for _ in range(num_iafs)]
self.iafs_modules = nn.ModuleList(self.iafs)

This appears to create a sequence of transforms and then collect them in a ModuleList, but it does not appear to register the ModuleList. And it does all this in a torch.nn.Module rather than a PyroModule class, so it wouldn’t be registered automatically.

But then I just now noticed both the model() and guide() functions in the class have

pyro.module('dmm', self)

which makes me think that must be registering the normalizing flows as well as all the other PyTorch attributes within the class. So I think I got that cleared up now, thanks!

Yes, that’s right. Note that it’s not necessary to register things in both the model and the guide, as long as you register them at least once.