A complex Bayesian CNN regression model

Hello,
I am trying to use the variational method (pyro.infer.SVI) for my project to infer a single numbers (and variational errors on it) given a 2D data matrix. I consider weights and the biases of the network to be the latent parameters which are supposed to be varied stochastically. Somehow I manage to write a code (the soul part of which is following) which is running and training. It is computing ELBO and the value is going down to 4 order of magnitude from the start (within 400 epochs). Although when I am trying to predict the values using the trained network. It is returning the samples dictionary empty. Am I doing something technically wrong here? Or is this the issue with the sensitivity of the initial random choice of weights and biases? The code snippet is as follows.

I have built this particular architecture from a CNN which is performing great and producing nice deterministic results.

class MyBNN(PyroModule):
    #------------- instantiating the model below -------------#
    def __init__(self):
        super(MyBNN, self).__init__() # adopting from the self parent class

        #---------- adding the convolutional and pooling layers -----------#
        self.conv1 = PyroModule[nn.Conv2d](in_channels=1, out_channels=32, kernel_size=3
                                        , stride=1, padding='valid', bias=True) # valid padding to stop any padding
        self.conv1.weight = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.conv1.weight.size()).to_event(self.conv1.weight.dim()))
        self.conv1.bias = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.conv1.bias.size()).to_event(self.conv1.bias.dim()))                          

        self.conv2 = PyroModule[nn.Conv2d](in_channels=32, out_channels=8, kernel_size=3
                                        , stride=1, padding='valid', bias=True) # valid padding to stop any padding
        self.conv2.weight = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.conv2.weight.size()).to_event(self.conv2.weight.dim()))
        self.conv2.bias = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.conv2.bias.size()).to_event(self.conv2.bias.dim()))
        
        #---------- adding the linear layers for the dense NN -----------#

        flat_size = 126*5*8
        # 128 and 5 is coming from the image size, 8 is the number of channels in last conv layer

        self.lin1 = PyroModule[nn.Linear](in_features=flat_size, out_features=16,bias=True)
        self.lin1.weight = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.lin1.weight.size()).to_event(self.lin1.weight.dim())) 
        self.lin1.bias = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.lin1.bias.size()).to_event(self.lin1.bias.dim())) 
        
        
        self.lin2 = PyroModule[nn.Linear](in_features=16, out_features=4,bias=True)
        self.lin2.weight = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.lin2.weight.size()).to_event(self.lin2.weight.dim()))
        self.lin2.bias = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.lin2.bias.size()).to_event(self.lin2.bias.dim())) 
        
        self.lin3 = PyroModule[nn.Linear](in_features=4, out_features=1,bias=True)
        self.lin3.weight = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.lin3.weight.size()).to_event(self.lin3.weight.dim()))
        self.lin3.bias = PyroSample(dist.Normal(torch.tensor(0.,device="cuda"),1.).expand(self.lin3.bias.size()).to_event(self.lin3.bias.dim()))
        
    
    #------------- Now arranging the layers for forwarding the model -------------#

    def forward(self, x, y=None):
        x = x.view(x.size(0), 1, x.size(1),x.size(2)) # reshaping x in form of (Nsamp, Chan, Height, Width)
        x = F.max_pool2d(self.conv1(x), kernel_size=2, ceil_mode=False)
        x = F.relu(x)

        x = F.avg_pool2d(self.conv2(x), kernel_size=2, ceil_mode=False)
        x = F.relu(x)

        x = x.view(x.size(0), -1) # flattening the outputs of the last convolutional layer
        
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x) # no rectifier as this is the output

        return x # returning the output tensor from here

##################################################################

lrate = 1.e-3 # learning rate of the Adam optimizer
Nepoch = 1000

mymodel = MyBNN().to(device) # this instantiate the model class

guide = AutoDiagonalNormal(mymodel)
optm = pyro.optim.Adam({"lr":lrate})
svi = SVI(mymodel, guide, optm, loss=Trace_ELBO())

##################################################################

# Splitting into training and test set
x_train, x_val, y_train, y_val = train_test_split(all_spec, all_tau, train_size=0.8, test_size=0.2, random_state=10)

x_train = torch.tensor(x_train, dtype=torch.float32)
x_val = torch.tensor(x_val, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.float32)

x_train = x_train.to(device)
x_val = x_val.to(device)
y_train = y_train.to(device)
y_val = y_val.to(device)

train_ds = TensorDataset(x_train, y_train) # setting training dataset tensor
train_dl = DataLoader(train_ds, batch_size=50, shuffle=True) # this is the loader api which pass the dataset at every epoch

valid_ds = TensorDataset(x_val, y_val) # setting validation dataset tensor
valid_dl = DataLoader(valid_ds, batch_size=100) # this is the dataloader for validation set

##################################################################
pyro.clear_param_store()

for i in tqdm(range(0, Nepoch), desc="Epochs", unit=" Epochs"):
    # caluating losses and taking a gradient step
    train_loss = 0.
    nums = 0
    for xd, yd in train_dl:
        loss = svi.step(xd, yd)
        train_loss += loss
        nums += len(yd)
    
    train_loss /=nums

    val_loss = 0.
    nums = 0
    for xd, yd in valid_dl:
        val_loss += svi.evaluate_loss(xd, yd)
        nums += len(yd)
    
    val_loss /=nums
    print(i, train_loss, val_loss)


def predict(model, x_val, y_val, num_samples=1000):
    predictive_distribution = pyro.infer.Predictive(model, guide=guide, num_samples=num_samples, return_sites=("obs",))
    samples = predictive_distribution(x_val, y_val)
    y_preds = samples["obs"]
    return y_preds

pred = predict(mymodel, x_val, y_val, num_samples=1000)

Thanks in advance!