How does pyro.random_module match priors with RegressionModel parameters?

Hi, I’m following the Bayesian regression tutorial.
I thought that the names of the priors (here: priors = {'linear.weight': w_prior, 'linear.bias': b_prior}) have to correspond to the parameters in RegressionModel so that they can be matched, but it appears that other names work as well (as long as they are consistent between model and guide). So how does Pyro match priors with model parameters?
Thanks.

No, the names need to match. Can you give an example in which misnaming parameters works?

Hi, thanks for considering my question. I looked into it some more and can see that the names need to match, but I’m still a bit confused about what happens when they don’t. The original output of the Bayesian regression tutorial is:

[iteration 0001] loss: 451.4380
[iteration 0101] loss: 9.5868
[iteration 0201] loss: 1.7982
[iteration 0301] loss: -0.7607
[iteration 0401] loss: -1.2017
[iteration 0501] loss: -1.2463
[iteration 0601] loss: -1.2553
[iteration 0701] loss: -1.2556
[iteration 0801] loss: -1.2098
[iteration 0901] loss: -1.1862
[guide_mean_weight]: 2.995
[guide_log_scale_weight]: -3.917
[guide_mean_bias]: 0.994
[guide_log_scale_bias]: -4.138

If I rename the keys of priors/dists to 'foo' and 'bar', I get:

[iteration 0001] loss: 417.5909
[iteration 0101] loss: 1.0944
[iteration 0201] loss: -0.7607
[iteration 0301] loss: -1.2853
[iteration 0401] loss: -1.3700
[iteration 0501] loss: -1.3785
[iteration 0601] loss: -1.3790
[iteration 0701] loss: -1.3790
[iteration 0801] loss: -1.3790
[iteration 0901] loss: -1.3790
[guide_mean_weight]: 0.728
[guide_log_scale_weight]: -2.993
[guide_mean_bias]: 0.814
[guide_log_scale_bias]: -3.105
[module$$$linear.weight]: 2.991
[module$$$linear.bias]: 1.005

Since I’m calling SVI with loss=Trace_ELBO(), how is the loss being minimized? Am I just getting a maximum likelihood estimate for module$$$linear.weight and module$$$linear.bias?
Full code below.

import os
import numpy as np
import torch
import torch.nn as nn

import pyro
from pyro.distributions import Normal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

if __name__ == '__main__':
    # for CI testing
    smoke_test = ('CI' in os.environ)
    pyro.enable_validation(True)

    N = 100  # size of toy data

    def build_linear_dataset(N, p=1, noise_std=0.01):
        X = np.random.rand(N, p)
        # w = 3
        w = 3 * np.ones(p)
        # b = 1
        y = np.matmul(X, w) + np.repeat(1, N) + np.random.normal(0, noise_std, size=N)
        y = y.reshape(N, 1)
        X, y = torch.tensor(X).type(torch.Tensor), torch.tensor(y).type(torch.Tensor)
        data = torch.cat((X, y), 1)
        assert data.shape == (N, p + 1)
        return data
        
    class RegressionModel(nn.Module):
        def __init__(self, p):
            # p = number of features
            super(RegressionModel, self).__init__()
            self.linear = nn.Linear(p, 1)

        def forward(self, x):
            return self.linear(x)

    regression_model = RegressionModel(1)

    def model(data):
        # Create normal priors (loc 0, scale 10) over the parameters
        loc, scale = torch.zeros(1, 1), 10 * torch.ones(1, 1)
        bias_loc, bias_scale = torch.zeros(1), 10 * torch.ones(1)
        w_prior = Normal(loc, scale).independent(1)
        b_prior = Normal(bias_loc, bias_scale).independent(1)
        # priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
        priors = {'foo': w_prior, 'bar': b_prior}
        # lift module parameters to random variables sampled from the priors
        lifted_module = pyro.random_module("module", regression_model, priors)
        # sample a regressor (which also samples w and b)
        lifted_reg_model = lifted_module()
        with pyro.iarange("map", N):
            x_data = data[:, :-1]
            y_data = data[:, -1]

            # run the regressor forward conditioned on data
            prediction_mean = lifted_reg_model(x_data).squeeze(-1)
            # condition on the observed data
            pyro.sample("obs",
                        Normal(prediction_mean, 0.1 * torch.ones(data.size(0))),
                        obs=y_data)
                        
    softplus = torch.nn.Softplus()

    def guide(data):
        # define our variational parameters
        w_loc = torch.randn(1, 1)
        # note that we initialize our scales to be pretty narrow
        w_log_sig = -3.0 * torch.ones(1, 1) + 0.05 * torch.randn(1, 1)
        b_loc = torch.randn(1)
        b_log_sig = -3.0 * torch.ones(1) + 0.05 * torch.randn(1)
        # register learnable params in the param store
        mw_param = pyro.param("guide_mean_weight", w_loc)
        sw_param = softplus(pyro.param("guide_log_scale_weight", w_log_sig))
        mb_param = pyro.param("guide_mean_bias", b_loc)
        sb_param = softplus(pyro.param("guide_log_scale_bias", b_log_sig))
        # guide distributions for w and b
        w_dist = Normal(mw_param, sw_param).independent(1)
        b_dist = Normal(mb_param, sb_param).independent(1)
        # dists = {'linear.weight': w_dist, 'linear.bias': b_dist}
        dists = {'foo': w_dist, 'bar': b_dist}
        # overload the parameters in the module with random samples
        # from the guide distributions
        lifted_module = pyro.random_module("module", regression_model, dists)
        # sample a regressor (which also samples w and b)
        return lifted_module()
    
    optim = Adam({"lr": 0.05})
    svi = SVI(model, guide, optim, loss=Trace_ELBO())
    num_iterations = 1000 if not smoke_test else 2
    
    pyro.clear_param_store()
    data = build_linear_dataset(N)
    for j in range(num_iterations):
        # calculate the loss and take a gradient step
        loss = svi.step(data)
        if j % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / float(N)))
            
    for name in pyro.get_param_store().get_all_param_names():
        print("[%s]: %.3f" % (name, pyro.param(name).item()))

Then random_module is basically doing nothing: none of your module parameters are being lifted to random variables, so it is just running a normal nn (i.e. there is no connection between your priors and the data). This is evident from your output: the param names module$$$[param] show that the params haven't been lifted, and their values are the MLE estimates, which are close to the posterior means of the Bayesian version above. So yes, you are effectively getting maximum likelihood estimates. In Pyro 0.3, you will get a warning when your names don't match.
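As a rough mental model of the reply above (and of the answer at the end of this thread), the matching can be sketched in plain Python. lift_by_name is a hypothetical helper, not Pyro's implementation: it assumes Pyro compares each prior key against the names from named_parameters(), lifts only exact matches, and leaves every other parameter in the param store under the module$$$ prefix seen in the output above.

```python
# Hypothetical, simplified model of the name matching described in this
# thread. This is NOT Pyro's actual random_module code.

def lift_by_name(param_names, prior_keys, module_name="module"):
    """Split a module's parameters into lifted and non-lifted ones.

    param_names: names as yielded by nn_module.named_parameters()
    prior_keys:  keys of the priors/dists dict passed to random_module
    Returns (lifted, kept, unused):
      lifted - parameter names that have a prior with exactly the same name
      kept   - param-store names of parameters with no matching prior
      unused - prior keys that match no parameter at all
    """
    prior_keys = set(prior_keys)
    lifted = [n for n in param_names if n in prior_keys]
    # Non-lifted parameters stay ordinary params, prefixed "module$$$".
    kept = [f"{module_name}$$${n}" for n in param_names if n not in prior_keys]
    unused = sorted(prior_keys - set(param_names))
    return lifted, kept, unused


names = ["linear.weight", "linear.bias"]
print(lift_by_name(names, ["foo", "bar"]))
# ([], ['module$$$linear.weight', 'module$$$linear.bias'], ['bar', 'foo'])
print(lift_by_name(names, ["linear.weight", "linear.bias"]))
# (['linear.weight', 'linear.bias'], [], [])
```

The first call mirrors the situation in the question: with keys 'foo' and 'bar' nothing is lifted, and both parameters remain plain params, just like the module$$$linear.weight and module$$$linear.bias entries in the output.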

Terrific, thanks a lot for your help. One more dumb question: how would I change the code if I wanted to keep some of the parameters (say linear.bias in the regression model) constant? To be clear: the final output for module$$$linear.bias should be equal to the initialization of linear.bias and not equal to 1. Can I still use pyro.random_module for this?

P.S. I guess I could use a Delta distribution, but this seems a bit ugly…

When we name the priors, can we add the instance name as a prefix, like the following, as long as the prior names are identical between model and guide and the dimensions are consistent with the tensors they represent?

def model(data):
    ...
    priors = {'regression_model.linear.weight': w_dist, 'regression_model.linear.bias': b_dist}
    ...
    lifted_module = pyro.random_module("module", regression_model, priors)
    ...

def guide(data):
    ...
    priors = {'regression_model.linear.weight': w_dist, 'regression_model.linear.bias': b_dist}
    ...
    lifted_module = pyro.random_module("module", regression_model, priors)
    ...

I tried this on a convnet, using the convnet's instance name as the prefix, and the prediction accuracy is much higher than without the prefix: about 97% vs. 90%. I hope it is OK for the prior names to deviate from the parameter names in the convnet definition :slight_smile:
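A quick check of the prefix idea, assuming (as the answer at the end of this thread says) that Pyro matches prior keys against the names from named_parameters(). The Python variable name of the instance never appears in those names, so the prefixed keys match nothing. By the explanation given earlier in the thread, unmatched keys mean the parameters are not lifted and the network is trained as an ordinary nn.

```python
import torch.nn as nn


# Same module as in the tutorial code earlier in this thread.
class RegressionModel(nn.Module):
    def __init__(self, p):
        super().__init__()
        self.linear = nn.Linear(p, 1)

    def forward(self, x):
        return self.linear(x)


regression_model = RegressionModel(1)

# Parameter names are attribute paths inside the module; the variable
# name "regression_model" is not part of them.
param_names = [name for name, _ in regression_model.named_parameters()]
print(param_names)  # ['linear.weight', 'linear.bias']

# Keys proposed above, with the instance name as a prefix.
prefixed_keys = ["regression_model.linear.weight", "regression_model.linear.bias"]
print(set(prefixed_keys) & set(param_names))  # set()
```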

What is the exact process of connecting the priors with the data? I searched the Pyro documentation and the website, but did not find the details. Thanks.

> What is the exact process of connecting the priors with the data?

It uses the names you would get by printing the names from nn_module.named_parameters():

for name, _ in nn_module.named_parameters():
    print(name)
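To see what those names look like for a deeper network, the same loop can be run on a small convnet. SmallConvNet below is a made-up example, not code from this thread. Layers inside nn.Sequential are named by their position (head.1.weight), and, as noted earlier in the thread, each prior or guide distribution should also have the same shape as the parameter it stands in for, so the sketch collects both name and shape.

```python
import torch
import torch.nn as nn


# Hypothetical small convnet, only to show how parameter names are formed.
class SmallConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)  # 28x28 -> 26x26
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(8 * 26 * 26, 10))

    def forward(self, x):
        return self.head(torch.relu(self.conv1(x)))


net = SmallConvNet()

# Names (dict keys) and shapes a priors/dists dict would need to match.
shapes = {name: tuple(p.shape) for name, p in net.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)
# conv1.weight (8, 1, 3, 3)
# conv1.bias (8,)
# head.1.weight (10, 5408)
# head.1.bias (10,)
```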