PyroSample and CUDA GPU

Hi,

I’m experimenting with training a simple Bayesian feedforward network on the GPU. The model trains on the CPU just fine, but when I try training on the GPU, it gives an error about tensors being on more than one device.

I’ve tracked the error down to the PyroSample statements that replace the nn.Linear layer’s parameters with Pyro priors. When I comment out these statements, the model runs fine on both the CPU and GPU. With them included, the model fails, I believe because the self.cuda() call I’m using does not put the PyroSample priors on the GPU.

So, my question is: what is the best way to put the PyroSample weight/bias priors (which override the PyTorch parameters of the Linear layer) on the GPU? I pasted my simple code below. It seems like there should be an easy solution, but my self.cuda() doesn’t seem to work.

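import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

# Put data on GPU.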
X, y = X.cuda(), y.cuda()

# Specify model.
class TestNN(PyroModule):
    def __init__(self, in_features, out_features=1):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](in_features, out_features)
        # Replace network layer parameters with Pyro priors.
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))  # -->>> how to put these parameters on the GPU??
        self.fc1.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))  # -->>> how to put these parameters on the GPU??
        self.cuda()
        
    def forward(self, x, y=None):
        p = torch.sigmoid(self.fc1(x))
        # Likelihood.
        with pyro.plate('data', x.shape[0]):
            obs = pyro.sample('obs', dist.Bernoulli(p), obs=y)
        return p

# The below also doesn't seem to put the parameters on the GPU.
test_nn = TestNN(in_features=X.shape[1])
test_nn.cuda()  # doesn't work??


Floats don’t specify devices, so you probably need to be more explicit, e.g.

PyroSample(dist.Normal(0., torch.tensor(1.0, device=X.device)).expand([out_features, in_features]).to_event(2))
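
To illustrate the point about floats, here is a minimal sketch using plain torch.distributions (which Pyro's distributions build on). The `device` variable and its CPU fallback are assumptions added for illustration; on a machine without CUDA both samples land on the CPU, so the difference only shows up on a GPU machine.

```python
import torch
from torch.distributions import Normal

# Assumption: fall back to the CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Python floats carry no device, so this prior lives on the
# default device (the CPU unless that default was changed).
float_prior = Normal(0., 1.).expand([1, 3])

# Passing one argument as a tensor on `device` places the prior there.
tensor_prior = Normal(0., torch.tensor(1.0, device=device)).expand([1, 3])

float_sample = float_prior.sample()
tensor_sample = tensor_prior.sample()
print(float_sample.device, tensor_sample.device)
```

If the two printed devices differ, any layer that mixes a weight sampled from the float-based prior with GPU inputs would be expected to raise the "more than one device" error described above.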

Ah, it looks like that did fix it by setting one of the parameters to a torch.tensor with the device explicitly stated. Thanks!

I tried it with the mean parameter and then with the scale parameter, and it worked in both cases. I'm not sure whether it is preferable to specify the mean or the scale with torch.tensor(..., device=...), but either way it seems that only one of the parameters needs it, rather than both, for some reason.
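
A likely explanation for why one tensor argument is enough: torch.distributions passes its arguments through `broadcast_all`, which converts plain Python numbers into tensors that match the dtype and device of the first tensor argument it finds. The sketch below calls that utility directly; the `device` fallback is an assumption for illustration, and this shows the underlying PyTorch behavior rather than anything specific to PyroSample.

```python
import torch
from torch.distributions.utils import broadcast_all

# Assumption: fall back to the CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Plain float for the mean, device-carrying tensor for the scale.
loc, scale = broadcast_all(0., torch.tensor(1.0, device=device))

# Roles swapped: tensor for the mean, plain float for the scale.
loc2, scale2 = broadcast_all(torch.tensor(0.0, device=device), 10.)

# In both cases the float is promoted to a tensor on the other argument's device.
print(loc.device, scale.device, loc2.device, scale2.device)
```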

Hi all, could you help me with moving PyroSample priors to the GPU? I just found this post, but I still could not fix my Pyro model. I put a simple example here, just to check how to transfer it to the GPU. The model now runs, but it does not seem to be using the GPU! However, when I use plain PyTorch without adding PyroSample statements, it does use the GPU.

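import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
from pyro.nn import PyroModule, PyroSample

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")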

class MyFirstBNN(PyroModule):
    def __init__(self, in_dim=1, out_dim=1, hid_dim=14, prior_scale=10.):
        super().__init__()

        self.activation = nn.ReLU()
        self.layer1 = PyroModule[nn.Linear](in_dim, hid_dim)  # Input to hidden layer
        self.layer2 = PyroModule[nn.Linear](hid_dim, out_dim)  # Hidden to output layer
        # Move the model to the appropriate device
        self.to(device)

        # Set layer parameters as random variables
        self.layer1.weight = PyroSample(dist.Normal(0., torch.tensor(prior_scale, device=device)).expand([hid_dim, in_dim]).to_event(2))
        self.layer1.bias = PyroSample(dist.Normal(0., torch.tensor(prior_scale, device=device)).expand([hid_dim]).to_event(1))
        self.layer2.weight = PyroSample(dist.Normal(0., torch.tensor(prior_scale, device=device)).expand([out_dim, hid_dim]).to_event(2))
        self.layer2.bias = PyroSample(dist.Normal(torch.zeros([out_dim], device=device), prior_scale).to_event(1))

    def forward(self, x, y=None):
        x = x.reshape(-1, 1).to(device)
        x = self.activation(self.layer1(x))
        mu = self.layer2(x).squeeze()
        sigma = pyro.sample("sigma", dist.Gamma(torch.tensor(.1, device=device), torch.tensor(.2, device=device)))

        # Sampling model
        if y is not None:  # y is None when the model is called without observations
            y = y.to(device)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mu, sigma * sigma), obs=y)
        return mu

model = MyFirstBNN(in_dim=1, out_dim=1, hid_dim=800, prior_scale=10.).to(device)

pyro.set_rng_seed(42)
nuts_kernel = NUTS(model, jit_compile=True) # jit_compile=True is faster but requires PyTorch 1.6+

mcmc = MCMC(nuts_kernel, num_samples=100)

x_train = x_train.to(device) # Move input data to the GPU
y_train = y_train.to(device) # Move target data to the GPU
mcmc.run(x_train, y_train)
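
One way to check where the tensors actually live is to inspect their `.device` attribute. The sketch below is only an illustration: `report_devices` is a hypothetical helper, not part of Pyro or PyTorch, and `fake_samples` is a stand-in dict with made-up shapes rather than real posterior samples.

```python
import torch


def report_devices(tensors):
    """Hypothetical helper: map each name to the device string of its tensor."""
    return {name: str(t.device) for name, t in tensors.items()}


# Assumption: fall back to the CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for a dict of named sample tensors (shapes are illustrative only).
fake_samples = {
    "sigma": torch.ones(100, device=device),
    "layer1.weight": torch.zeros(100, 14, 1, device=device),
}

devices = report_devices(fake_samples)
print(devices)
```

After `mcmc.run(...)`, the same helper could be applied to the dict returned by `mcmc.get_samples()`. If every entry reports a CUDA device, the sampling is taking place on the GPU, and slow runs or low utilization would then have some other cause.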