I’m trying to modify an existing custom PyTorch Module to be Bayesian by replacing the weights within the __init__() call with PyroSample instances. However, I’m getting a naming conflict: RuntimeError: Multiple sample sites named 'weight'. Is there a way to use PyroSample while resolving this name conflict, for example by passing in a name argument? I saw from this post that I could do this by switching to pyro.sample statements, but that would involve a fair amount of changes to the Module code that I would like to avoid if possible.
Additionally, more of an aside: is the only way to use the GPU with a PyroSample or pyro.sample to pass in a torch.tensor object already located on the GPU? e.g. PyroSample(dist.Normal(torch.tensor(0., device='cuda:0'), 1.))
Did you see this? http://pyro.ai/examples/modules.html#⚠-Caution:-avoiding-duplicate-names
For the last point, can’t you register a buffer, self.register_buffer('prior_mean', torch.tensor(0.0)), and then consistently use that in your PyroSample statements? Then .cuda()/.to() should work.
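For example, a minimal sketch of the idea with a forward-time sample site (module and site names here are just illustrative):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule

class NoiseModel(PyroModule):
    def __init__(self):
        super().__init__()
        # Buffers travel with the module under .to()/.cuda() ...
        self.register_buffer('prior_mean', torch.tensor(0.0))

    def forward(self):
        # ... so a distribution built from the buffer at call time
        # ends up on whatever device the module was moved to.
        return pyro.sample('sigma', dist.Uniform(self.prior_mean, 1.0))

m = NoiseModel().to('cuda:0')
sigma = m()  # sampled on cuda:0
```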
Thank you, both of these suggestions were helpful! After reading the naming convention documentation, it turns out I had accidentally omitted the PyroModule mixin from a torch.nn.Sequential constructor.
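For reference, here is a minimal sketch of the mistake and the fix (layer sizes are illustrative): a plain torch.nn.Sequential doesn’t prefix its children’s sample-site names, so two Bayesian layers both register a site named 'weight', whereas the PyroModule[...] mixin syntax adds the prefixes.

```python
import torch
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianLinear(torch.nn.Linear, PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.weight = PyroSample(
            dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))

class Net(PyroModule):
    def __init__(self):
        super().__init__()
        # Wrong: a plain torch.nn.Sequential does not prefix its children's
        # sample sites, so both layers register a site named 'weight':
        #   self.layers = torch.nn.Sequential(
        #       BayesianLinear(10, 20), BayesianLinear(20, 1))
        # Right: the PyroModule mixin prefixes them as
        # 'layers.0.weight' and 'layers.1.weight':
        self.layers = PyroModule[torch.nn.Sequential](
            BayesianLinear(10, 20), BayesianLinear(20, 1))

    def forward(self, x):
        return self.layers(x)
```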
And self.register_buffer allowed the registered tensor to be moved onto the device as suggested.
I spoke too soon with regard to the second point: self.register_buffer works if the dist constructor is in the forward call, e.g. the sigma = pyro.sample('sigma', dist.Uniform(self.sigma_prior_mean, 1.0)) line works correctly. However, the self.bias_prior_mean used in the PyroSample call for self.bias in the constructor does not correctly move the distribution to the GPU device, presumably because that distribution is already instantiated in the constructor. Since (AFAIK) the PyroSample statements should be in the __init__ call, is there any way around this?
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class BayesianLinear(torch.nn.Linear, PyroModule):
    '''Bayesian Linear Layer'''

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer('weight_prior_mean', torch.tensor(0.0))
        self.register_buffer('bias_prior_mean', torch.tensor(0.0))
        self.register_buffer('sigma_prior_mean', torch.tensor(0.0))
        # These priors are instantiated once, here, from the CPU buffers:
        self.weight = PyroSample(
            dist.Normal(self.weight_prior_mean, 1.0)
            .expand([out_features, in_features]).to_event(2)
        )
        if bias:
            self.bias = PyroSample(
                dist.Normal(self.bias_prior_mean, 1.0)
                .expand([out_features]).to_event(1)
            )

    def forward(self, x, y=None):
        # This dist is built per call, so it sees the buffer's current device:
        sigma = pyro.sample('sigma', dist.Uniform(self.sigma_prior_mean, 1.0))
        mean = super().forward(x)
        with pyro.plate('data', size=x.shape[0], dim=-2):
            obs = pyro.sample('obs', dist.Normal(mean, sigma), obs=y)
        return mean


if __name__ == '__main__':
    model = BayesianLinear(10, 20)
    model = model.to(device='cuda:0')
    data = torch.randn(4, 10).to(device='cuda:0')
    out = model(data)  # This throws a device mismatch
```
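One possible way around this (a sketch based on PyroSample also accepting a callable, per the PyroModule docs, rather than something tested against this exact model): pass a lambda taking the module itself, so the prior is rebuilt lazily on each attribute access and picks up the buffers’ current device.

```python
import torch
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianLinear(torch.nn.Linear, PyroModule):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer('weight_prior_mean', torch.tensor(0.0))
        self.register_buffer('bias_prior_mean', torch.tensor(0.0))
        # The lambda is re-evaluated every time the attribute is accessed,
        # so the prior is built from the buffers on their current device
        # rather than frozen on the CPU at construction time.
        self.weight = PyroSample(
            lambda self: dist.Normal(self.weight_prior_mean, 1.0)
            .expand([self.out_features, self.in_features]).to_event(2))
        if bias:
            self.bias = PyroSample(
                lambda self: dist.Normal(self.bias_prior_mean, 1.0)
                .expand([self.out_features]).to_event(1))
```

With this change, model.to('cuda:0') should be enough for the weight and bias priors to follow the buffers, since nothing device-specific is baked in at __init__ time.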