Is there any way to sample parameters from pre-trained PyTorch models?

This might be a weird question: I have been looking into Pyro for some time but could not find a solution.
Currently, I have a pre-trained PyTorch model, and I would like to sample its weight parameters using pyro.infer.HMC. Because the model is already trained, I am planning to use the loaded weights as the mean of each weight's prior, and the L2 regularization strength as the prior variance.
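For reference, here is how an L2 penalty corresponds to a Gaussian prior, assuming the training loss is the summed (not averaged) negative log-likelihood. Keep in mind that the `scale` argument of `dist.Normal` is a standard deviation, not a variance:

```latex
-\log \mathcal{N}(w \mid \mu, \sigma^2) = \frac{(w - \mu)^2}{2\sigma^2} + \text{const},
\qquad\text{so}\qquad
\frac{\lambda}{2}\lVert w - \mu \rVert^2 \;\Longleftrightarrow\; \sigma^2 = \frac{1}{\lambda},
\quad \text{scale} = \sigma = \lambda^{-1/2}.
```

The penalty (λ/2)‖w‖² is the one whose gradient is λw, which is how PyTorch's SGD applies `weight_decay` (with μ = 0). If the loss is averaged over N training examples, the same reasoning gives σ² = 1/(Nλ).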

However, in Pyro it seems that I have to define a probabilistic model and train it before I can apply HMC sampling. This does not make use of the pre-trained model. Is there any way I can use my pre-trained model and apply HMC sampling?

I would like to do something like this:

import torch
import torch.nn as nn
from pyro.infer import HMC, MCMC

class Net(nn.Module):
  ...

net = Net()
net.load_state_dict(torch.load(path))

hmc_kernel = HMC(net, step_size=0.01, num_steps=10)
mcmc = MCMC(hmc_kernel, num_samples=500, warmup_steps=100)
for data, target in data_loader:
  mcmc.run(data)
  mcmc.get_samples()

Thank you in advance.

Hi @sff1019,
interesting use case! I believe you can achieve this using PyroModule and to_pyro_module_(). Something like this should work:

# First load a pre-trained module.
import torch
import torch.nn as nn
import pyro.distributions as dist

class Net(nn.Module):
    ...

net = Net()
net.load_state_dict(torch.load(path))

# Now convert the existing module to a PyroModule in-place.
from pyro.nn.module import to_pyro_module_, PyroSample
to_pyro_module_(net)

# At this point the module will effectively have pyro.param attributes.
# To start being Bayesian we convert these to pyro.sample attributes.
# I am unsure how to do this generically, so let's convert each attribute:
net.weight = PyroSample(dist.Normal(loc=net.weight.detach(),
                                    scale=my_l2_regularizer)
                            .to_event(net.weight.dim()))
net.bias = PyroSample(dist.Normal(loc=net.bias.detach(),
                                  scale=my_l2_regularizer)
                          .to_event(net.bias.dim()))
...

At this point you should be able to run HMC, SVI, or any other Pyro inference algorithm:

hmc_kernel = HMC(net, ...)

Let me know how this works out,
Fritz

Thank you for the response, @fritzo.

I’ve tried out this method, but it does not seem to be sampling anything…
Am I missing something?

Here is the simplified code where I use Pyro:

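import torch
import pyro.distributions as dist
from pyro.infer import HMC, MCMC
from pyro.nn.module import to_pyro_module_, PyroSample

device = torch.device('cuda')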

# Load pre-trained models
net = AlexNet()
net = torch.load('alexnet').to(device)

# Convert to Pyro modules
to_pyro_module_(net)

# Iterate through the modules, and convert these to pyro.sample attributes
for name, module in net.named_parameters():
    module = PyroSample(dist.Normal(
        loc=module.detach(), scale=args.weight_decay
    ).to_event(module.dim()))

hmc_kernel = HMC(net, step_size=0.0855, num_steps=4)
mcmc = MCMC(hmc_kernel, num_samples=100, warmup_steps=100)

for data, target in train_loader:
    data, target = data.to(device), target.to(device)
    mcmc.run(data)
    print(mcmc.get_samples())

The converted Pyro module looks like this:

PyroAlexNet(
  (conv): PyroSequential(
    (0): PyroConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (1): PyroReLU(inplace=True)
    (2): PyroMaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): PyroConv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
    (4): PyroReLU(inplace=True)
    (5): PyroMaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): PyroConv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): PyroReLU(inplace=True)
    (8): PyroConv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): PyroReLU(inplace=True)
    (10): PyroConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): PyroReLU(inplace=True)
    (12): PyroMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): PyroAdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): PyroSequential(
    (0): PyroDropout(p=0.5, inplace=False)
    (1): PyroLinear(in_features=9216, out_features=4096, bias=True)
    (2): PyroReLU(inplace=True)
    (3): PyroDropout(p=0.5, inplace=False)
    (4): PyroLinear(in_features=4096, out_features=4096, bias=True)
    (5): PyroReLU(inplace=True)
    (6): PyroLinear(in_features=4096, out_features=10, bias=True)
  )
)

Your automation is throwing away its results:

for name, module in net.named_parameters():
    module = PyroSample(dist.Normal(    # <-- this has no effect
        loc=module.detach(), scale=args.weight_decay
    ).to_event(module.dim()))
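The reason is that `module = ...` only rebinds the loop variable; neither the object the iterator yielded nor the container it came from is modified. The same effect in plain Python, with a dict standing in for the module's parameters (a hypothetical illustration):

```python
# A plain dict stands in for a module's parameters (hypothetical example).
params = {"weight": 1.0, "bias": 2.0}

for name, value in params.items():
    value = value * 10  # rebinds the local name only; params is unchanged

unchanged = dict(params)
print(unchanged)  # prints {'weight': 1.0, 'bias': 2.0}

# Writing back through the container is what actually changes it.
for name in list(params):
    params[name] = params[name] * 10

print(params)  # prints {'weight': 10.0, 'bias': 20.0}
```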

To automate that, you’d need something like a deep_setattr() helper:

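# Iterate over a snapshot: each replacement deletes a parameter that
# named_parameters() would otherwise still be walking over.
for name, param in list(net.named_parameters()):
    param = PyroSample(dist.Normal(
        loc=param.detach(), scale=args.weight_decay
    ).to_event(param.dim()))
    deep_setattr(net, name, param)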

where something like this might work:

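def deep_setattr(obj, name, value):
    # Follow a dotted name such as "conv.0.weight" down to the owning
    # submodule, then set the last attribute on that submodule.
    parts = name.split(".")
    for part in parts[:-1]:
        obj = getattr(obj, part)
    setattr(obj, parts[-1], value)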