PyroSample and CUDA GPU


I’m experimenting with training a simple Bayesian feedforward network on the GPU. The model trains on the CPU just fine, but when I try training on the GPU, it gives an error about tensors being on more than one device.

I’ve tracked the error down to the PyroSample statements that replace the nn.Linear layer’s parameters with Pyro priors. When I comment out these statements, the model runs fine on both the CPU and GPU. When I include them, the model fails, and I believe that’s because the .cuda() call I’m using is not putting the Pyro priors on the GPU.

So, my question is: what is the best way to put the PyroSample weight/bias priors (which override the PyTorch parameters of the Linear layer) on the GPU? My simple code is below. It seems like there should be an easy solution, but calling .cuda() on the model doesn’t seem to work.

```python
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

# Put data on GPU (X and y are the existing feature and 0/1 label tensors).
X, y = X.cuda(), y.cuda()

# Specify model.
class TestNN(PyroModule):
    def __init__(self, in_features, out_features=1):
        super().__init__()  # required before assigning submodules
        self.fc1 = PyroModule[nn.Linear](in_features, out_features)
        # Replace network layer parameters with Pyro priors.
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))  # -->>> how to put these parameters on the GPU??
        self.fc1.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))  # -->>> how to put these parameters on the GPU??

    def forward(self, x, y=None):
        p = torch.sigmoid(self.fc1(x))
        # Likelihood.
        with pyro.plate('data', x.shape[0]):
            obs = pyro.sample('obs', dist.Bernoulli(p), obs=y)
        return p

# The below also doesn't seem to put the parameters on the GPU.
test_nn = TestNN(in_features=X.shape[1])
test_nn.cuda()  # doesn't work??
```

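Python floats don’t carry a device, so the tensors inside these priors end up on the default device (the CPU). Calling .cuda() on the module only moves its registered parameters and buffers, and the distribution stored by PyroSample is neither, so you probably need to be more explicit, e.g.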

```python
PyroSample(dist.Normal(0., torch.tensor(1.0, device=X.device)).expand([out_features, in_features]).to_event(2))
```

Ah, that did fix it: setting one of the parameters to a torch.tensor with the device stated explicitly works. Thanks!

I tried it with the mean parameter and then with the scale parameter, and it worked in both cases. I’m not sure whether there is a preferred choice between making the mean or the scale a torch.tensor(..., device=...), but either way it seems that only one of the two parameters needs it, not both, for some reason.