I’m trying to modify an existing custom PyTorch Module to be Bayesian by replacing the weights within the __init__() call with PyroSamples. However, I’m getting a naming conflict: RuntimeError: Multiple sample sites named 'weight'. Is there a way to use PyroSample while resolving this name conflict, for example by passing in a name argument?
I saw from this post that I could do this by switching to pyro.sample statements, but this would involve a fair amount of changes to the Module code that I would like to avoid if possible.
Additionally, more of an aside: is the only way to use the GPU with a pyro.sample to pass in a torch.Tensor already located on the GPU? e.g. PyroSample(dist.Normal(torch.tensor(0., device='cuda:0'), 1.))
Did you see this?
For the last point, can’t you register a buffer and then consistently use that in your PyroSample statements? Then .cuda()/.to() should work.
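As a minimal torch-only sketch of that suggestion (class and buffer names are made up for illustration): a tensor registered via register_buffer becomes part of the module’s state, so it follows the module through .to()/.cuda() just like parameters do.

```python
import torch


class Regressor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # register_buffer makes the tensor part of the module's state,
        # so model.to(device) / model.cuda() moves it with the parameters
        self.register_buffer('sigma_prior_mean', torch.tensor(0.))


model = Regressor()
print('sigma_prior_mean' in model.state_dict())  # True
print(model.sigma_prior_mean.device.type)        # cpu (until model.to('cuda:0'))
```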
Thank you, both of these suggestions were helpful!
After reading the naming convention documentation it turns out I had accidentally omitted the PyroModule mixin from one of my submodules, which caused the naming conflict. Using self.register_buffer allowed the registered tensor to be moved onto the device as suggested.
I spoke too soon with regards to the second point: self.register_buffer works if the dist constructor is in the forward call, e.g. the sigma = pyro.sample('sigma', dist.Uniform(self.sigma_prior_mean, 1.0)) line works correctly. However, the self.bias_prior_mean used in the PyroSample call for self.linear.bias in the constructor does not correctly move the distribution to the GPU device, presumably because that distribution is already instantiated in the constructor. Since (AFAIK) the PyroSample statements should be in the __init__ call, is there any way around this?
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianLinear(torch.nn.Linear, PyroModule):
    '''Bayesian Linear Layer'''
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        # Prior means registered as buffers so .to()/.cuda() moves them
        self.register_buffer('weight_prior_mean', torch.tensor(0.))
        self.register_buffer('bias_prior_mean', torch.tensor(0.))
        self.register_buffer('sigma_prior_mean', torch.tensor(0.))
        # These distributions are instantiated here, on the CPU
        self.weight = PyroSample(
            dist.Normal(self.weight_prior_mean, 1.0).expand([out_features, in_features]).to_event(2))
        self.bias = PyroSample(
            dist.Normal(self.bias_prior_mean, 1.0).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample('sigma', dist.Uniform(self.sigma_prior_mean, 1.0))
        mean = super().forward(x)
        with pyro.plate('data', size=x.shape[0], dim=-2):
            obs = pyro.sample('obs', dist.Normal(mean, sigma), obs=y)
        return mean

if __name__ == '__main__':
    model = BayesianLinear(10, 20)
    model = model.to(device='cuda:0')
    data = torch.randn(4, 10).to(device='cuda:0')
    out = model(data)  # This throws a device mismatch