Naming conflict when using PyroSample

I’m trying to make an existing custom PyTorch Module Bayesian by replacing its weights with PyroSamples in the __init__() call. However, I’m getting a naming conflict: RuntimeError: Multiple sample sites named 'weight'. Is there a way to resolve this name conflict while still using PyroSample, for example by passing in a name argument?

I saw from this post that I could work around it by switching to pyro.sample statements, but that would involve a fair number of changes to the Module code, which I’d like to avoid if possible.

Additionally, as more of an aside: is the only way to use the GPU with PyroSample or pyro.sample to pass in a torch.tensor already located on the GPU? e.g. PyroSample(dist.Normal(torch.tensor(0., device='cuda:0'), 1.))

did you see this?

http://pyro.ai/examples/modules.html#⚠-Caution:-avoiding-duplicate-names

For the last point, can’t you register a buffer,

self.register_buffer('prior_mean', torch.tensor(0.0))

and then consistently use that in your PyroSample statements? .cuda()/.to() should then move it along with the rest of the module.
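
Something like this minimal sketch (the class and attribute names are just illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule

class MyModel(PyroModule):
    def __init__(self):
        super().__init__()
        # buffers are moved along with the module by .cuda()/.to()
        self.register_buffer('prior_mean', torch.tensor(0.0))

    def forward(self):
        # the buffer is read at sample time, so the distribution ends up
        # on whatever device the module was moved to
        return pyro.sample('sigma', dist.Normal(self.prior_mean, 1.0))

model = MyModel().to('cuda:0')
sigma = model()  # sampled on the GPU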

Thank you, both of these suggestions were helpful!

After reading the naming-convention documentation, it turns out I had accidentally omitted the PyroModule mixin from a torch.nn.Sequential constructor. Roughly, the fix looked like the sketch below (the layer sizes are just placeholders):
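
import torch
from pyro.nn import PyroModule

# Before: a plain torch.nn.Sequential held the Bayesian layers, so their
# sample sites all ended up named 'weight'. Mixing PyroModule into the
# container lets Pyro prefix each child's sample sites with its name.
net = PyroModule[torch.nn.Sequential](
    PyroModule[torch.nn.Linear](10, 20),
    PyroModule[torch.nn.Linear](20, 1),
)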

And self.register_buffer allowed the registered tensor to be moved onto the device, as suggested.

I spoke too soon with regard to the second point.

self.register_buffer works when the distribution constructor is inside the forward call; e.g. the sigma = pyro.sample('sigma', dist.Uniform(self.sigma_prior_mean, 1.0)) line works correctly.

However, the self.bias_prior_mean used in the PyroSample call for self.bias in the constructor does not correctly move the distribution to the GPU device, presumably because that distribution is already instantiated in the constructor, before the module is moved. Since (AFAIK) the PyroSample statements should be in the __init__ call, is there any way around this?

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianLinear(torch.nn.Linear, PyroModule):
    '''Bayesian Linear Layer'''

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
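        # Buffers follow the module across .to()/.cuda() calls.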
        self.register_buffer('weight_prior_mean', torch.tensor(0.0))
        self.register_buffer('bias_prior_mean', torch.tensor(0.0))
        self.register_buffer('sigma_prior_mean', torch.tensor(0.0))
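        # These priors are constructed once, here, while the buffers are
        # still on the CPU.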
        self.weight = PyroSample(
            dist.Normal(self.weight_prior_mean, 1.0).expand([out_features, in_features]).to_event(2)
        )
        if bias:
            self.bias = PyroSample(
                dist.Normal(self.bias_prior_mean, 1.0).expand([out_features]).to_event(1)
            )

    def forward(self, x, y=None):
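        # This works on the GPU: the Uniform is built at call time, after
        # the buffer has been moved to the device.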
        sigma = pyro.sample('sigma', dist.Uniform(self.sigma_prior_mean, 1.0))
        mean = super().forward(x)
        with pyro.plate('data', size=x.shape[0], dim=-2):
            obs = pyro.sample('obs', dist.Normal(mean, sigma), obs=y)
        return mean

if __name__ == '__main__':
    model = BayesianLinear(10, 20)
    model = model.to(device='cuda:0')
    data = torch.randn(4, 10).to(device='cuda:0')
    out = model(data)  # device mismatch: the weight and bias priors are still sampled on the CPU
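
One workaround that should address this, going by the PyroSample docs: the prior can also be a callable that takes the module instance and returns a distribution, so the distribution is rebuilt lazily at sample time, after .to() has moved the buffers. A sketch of that change (LazyBayesianLinear is just a placeholder name, and this is untested on my end):

import torch
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class LazyBayesianLinear(torch.nn.Linear, PyroModule):
    '''Bayesian Linear Layer with lazily constructed priors'''

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer('weight_prior_mean', torch.tensor(0.0))
        self.register_buffer('bias_prior_mean', torch.tensor(0.0))
        # A callable prior is invoked with the module instance each time
        # the attribute is sampled, so it reads the buffers from their
        # current device rather than freezing them at construction time.
        self.weight = PyroSample(
            lambda self: dist.Normal(self.weight_prior_mean, 1.0)
            .expand([self.out_features, self.in_features])
            .to_event(2)
        )
        if bias:
            self.bias = PyroSample(
                lambda self: dist.Normal(self.bias_prior_mean, 1.0)
                .expand([self.out_features])
                .to_event(1)
            )

model = LazyBayesianLinear(10, 20).to(device='cuda:0')
out = model(torch.randn(4, 10, device='cuda:0'))  # no device mismatch here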