see, i am new to pyro and i built a network, but it says
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
in code, seem like the to(device) function did not work on the pyro layers and i don`t know why
from pyro.optim import Adam
def train(svi, train_loader):
epoch_loss = 0.
for x, y in train_loader:
x, y = x.to(device), y.to(device)
epoch_loss += svi.step(x, y) # where the problem showed
return epoch_loss / len(train_loader)
model = BayesianNeuralNetwork(input_channels=1, hidden_size=256, output_size=10).to(device)
guide = pyro.infer.autoguide.AutoDiagonalNormal(model).to(device)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
num_epochs = 50
for epoch in range(num_epochs):
loss = train(svi, train_loader)
print(f"Epoch {epoch+1}, Loss: {loss}")
# the network below
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from torch import nn
class BayesianNeuralNetwork(PyroModule):
def __init__(self, input_channels, hidden_size, output_size, kernel_size=5):
super().__init__()
self.conv1 = nn.Conv2d(input_channels, 20, kernel_size=kernel_size)
conv_output_size = ((28 - kernel_size + 1) ** 2) * 20
self.fc1 = PyroModule[nn.Linear](conv_output_size, hidden_size)
self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([hidden_size, conv_output_size]).to_event(2))
self.fc1.bias = PyroSample(dist.Normal(0., 10.).expand([hidden_size]).to_event(1))
self.fc2 = PyroModule[nn.Linear](hidden_size, output_size)
self.fc2.weight = PyroSample(dist.Normal(0., 1.).expand([output_size, hidden_size]).to_event(2))
self.fc2.bias = PyroSample(dist.Normal(0., 10.).expand([output_size]).to_event(1))
def forward(self, x, y=None):
x = torch.relu(self.conv1(x))
x = x.view(x.size(0), -1) # Flatten
print(x.device) #at cuda
print(self.fc1.weight.device) #at cpu
print(self.fc1.bias.device) # at cpu
x = torch.relu(self.fc1(x)) #showed error here
logits = self.fc2(x)
with pyro.plate("data", x.shape[0]):
obs = pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
return logits