Question about qyro tensors and cuda

see, i am new to pyro and i built a network, but it says

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

in code, seem like the to(device) function did not work on the pyro layers and i don`t know why

from pyro.optim import Adam

def train(svi, train_loader):
    epoch_loss = 0.
    for x, y in train_loader:
      x, y =,
      epoch_loss += svi.step(x, y) # where the problem showed
    return epoch_loss / len(train_loader)

model = BayesianNeuralNetwork(input_channels=1, hidden_size=256, output_size=10).to(device)

guide = pyro.infer.autoguide.AutoDiagonalNormal(model).to(device)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

num_epochs = 50
for epoch in range(num_epochs):
    loss = train(svi, train_loader)
    print(f"Epoch {epoch+1}, Loss: {loss}")

# the network below

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from torch import nn

class BayesianNeuralNetwork(PyroModule):
    def __init__(self, input_channels, hidden_size, output_size, kernel_size=5):
        self.conv1 = nn.Conv2d(input_channels, 20, kernel_size=kernel_size)

        conv_output_size = ((28 - kernel_size + 1) ** 2) * 20
        self.fc1 = PyroModule[nn.Linear](conv_output_size, hidden_size)
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([hidden_size, conv_output_size]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 10.).expand([hidden_size]).to_event(1))
        self.fc2 = PyroModule[nn.Linear](hidden_size, output_size)
        self.fc2.weight = PyroSample(dist.Normal(0., 1.).expand([output_size, hidden_size]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0., 10.).expand([output_size]).to_event(1))

    def forward(self, x, y=None):
        x = torch.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # Flatten
        print(x.device)                         #at cuda
        print(self.fc1.weight.device)    #at cpu
        print(self.fc1.bias.device)    # at cpu
        x = torch.relu(self.fc1(x))    #showed error here
        logits = self.fc2(x)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return logits