Hi, I was looking through the Transform/Normalizing Flow documentation and noticed that we're supposed to register our transforms with Pyro, for example:
pyro.module("my_transform", my_transform)
Then I saw there are helper functions (TransformFactories) that create some of these transforms more easily and with fewer arguments (for example, without requiring us to explicitly create a separate dense network). But there's nothing in the documentation about registering these transforms with Pyro, and the Deep Markov Model tutorial doesn't seem to register them either. When I looked at the source code of these helper functions, I noticed they don't register the transforms in the param store either.
So I experimented with a simple model and guide that uses normalizing flows (below), and it seems to perform the same whether or not I register the normalizing flow/transform parameters with Pyro. So I'm a little confused: is Pyro just keeping the transform parameters as PyTorch parameters and optimizing both the PyTorch and Pyro parameters at the same time? If not, why would the model perform the same whether or not I register the normalizing flow parameters in the Pyro param store?
I pasted the model, guide, and training code below.
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import TransformedDistribution, constraints
from pyro.distributions.transforms import neural_autoregressive

def flow_regression(is_cont_africa, ruggedness, log_gdp=None):
    # Priors.
    intercept = pyro.sample('intercept', dist.Normal(0., 10.))
    betas = pyro.sample('betas', dist.Normal(0., 1.).expand([3]).to_event(1))
    sigma = pyro.sample('sigma', dist.Uniform(0., 10.))
    # Linear model.
    mean = intercept + betas[0] * is_cont_africa + betas[1] * ruggedness + betas[2] * is_cont_africa * ruggedness
    # Likelihood.
    with pyro.plate('data', len(ruggedness)):
        return pyro.sample('obs', dist.Normal(mean, sigma), obs=log_gdp)

def flow_guide(is_cont_africa, ruggedness, log_gdp=None):
    # Normalizing flows.
    flows = neural_autoregressive(input_dim=3)
    pyro.module('flows', flows)  # <<<--- Do I Need This???
    # Variational parameters.
    intercept_loc = pyro.param('intercept_loc', lambda: torch.tensor(5.))
    intercept_scale = pyro.param('intercept_scale', lambda: torch.tensor(0.5), constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', lambda: torch.randn(3))
    weights_scale = pyro.param('weights_scale', lambda: torch.ones(3), constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(1.), constraint=constraints.positive)
    # Variational distributions.
    intercept = pyro.sample('intercept', dist.Normal(intercept_loc, intercept_scale))
    betas = pyro.sample('betas', TransformedDistribution(dist.Normal(weights_loc, weights_scale), flows))
    sigma = pyro.sample('sigma', dist.Normal(sigma_loc, torch.tensor(0.05)))
    return {'intercept': intercept, 'betas': betas, 'sigma': sigma}

clip_adam = pyro.optim.ClippedAdam({'lr': 0.01})
elbo = pyro.infer.Trace_ELBO(num_particles=5)
svi = pyro.infer.SVI(flow_regression, flow_guide, clip_adam, elbo)
pyro.clear_param_store()
losses = []
for step in range(1001):
    loss = svi.step(is_cont_africa, ruggedness, log_gdp)
    losses.append(loss)
    if step % 100 == 0:
        print(f"Epoch {step} - Elbo loss: {loss}")