Thank you for the response @fritzo
I’ve tried out this method, but it does not seem like it’s sampling anything…
Am I missing something?
This is the simple code where I use Pyro.
device = torch.device('cuda')
# Load pre-trained models
net = AlexNet()
net = torch.load('alexnet').to(device)
# Convert to Pyro modules
to_pyro_module_(net)
# Iterate through the modules, and convert these to pyro.sample attributes
for name, module in net.named_parameters():
module = PyroSample(dist.Normal(
loc=module.detach(), scale=args.weight_decay
).to_event(module.dim()))
hmc_kernel = HMC(net, step_size=0.0855, num_steps=4)
mcmc = MCMC(hmc_kernel, num_samples=100, warmup_steps=100)
for data, target in train_loader:
data, target = data.to(device), target.to(device)
mcmc.run(data)
print(mcmc.get_samples())
The pyro module is like this:
PyroAlexNet(
(conv): PyroSequential(
(0): PyroConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
(1): PyroReLU(inplace=True)
(2): PyroMaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): PyroConv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
(4): PyroReLU(inplace=True)
(5): PyroMaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): PyroConv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): PyroReLU(inplace=True)
(8): PyroConv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): PyroReLU(inplace=True)
(10): PyroConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): PyroReLU(inplace=True)
(12): PyroMaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): PyroAdaptiveAvgPool2d(output_size=(6, 6))
(classifier): PyroSequential(
(0): PyroDropout(p=0.5, inplace=False)
(1): PyroLinear(in_features=9216, out_features=4096, bias=True)
(2): PyroReLU(inplace=True)
(3): PyroDropout(p=0.5, inplace=False)
(4): PyroLinear(in_features=4096, out_features=4096, bias=True)
(5): PyroReLU(inplace=True)
(6): PyroLinear(in_features=4096, out_features=10, bias=True)
)
)