Conditional Spline Poor Performance

Hi,

I have implemented a neural network consisting of a convolutional layer followed by a ConditionalSpline transform. The idea is that, instead of predicting a value directly, the network predicts the transformation that maps a normal base distribution centered around 0 (a standard normal in the code below) onto a distribution representing the target.
Below is the code of my implementation. However, the network performs poorly, and I am wondering whether I have implemented it correctly, and in particular whether the optimizer is also optimizing the parameters of the conditional transform. Could you please check this short piece of code to see if I am using it correctly?

import torch
import torch.nn as nn

import pytorch_lightning as pl
import pyro.distributions as dist
from pyro.nn.dense_nn import DenseNN
from pyro.distributions.transforms import ConditionalSpline

class TCN(pl.LightningModule):
    def __init__(self, seq_length, n_channels=128, k_size=3, regressor_size=200):
        super().__init__()

        self.seq_length = seq_length
        self.n_channels, self.k_size, self.regressor_size = n_channels, k_size, regressor_size

        self.features = nn.Sequential(
            nn.Conv1d(in_channels=1,
                      out_channels=self.n_channels,
                      kernel_size=self.k_size,
                      padding=self.k_size // 2),
            nn.MaxPool1d(2),
            nn.ReLU(),
        )

        # Inspired by the ConditionalSpline example in Pyro. The DenseNN
        # hypernetwork maps the conv features (the context) to the spline
        # parameters: widths, heights, derivatives and lambdas.
        self.input_dim = 1
        self.context_dim = self.n_channels * (self.seq_length // 2)  # length halved by MaxPool1d(2)
        count_bins = 16

        param_dims = [self.input_dim * count_bins, self.input_dim * count_bins,
                      self.input_dim * (count_bins - 1), self.input_dim * count_bins]
        self.hypernet = DenseNN(self.context_dim, [20, 20], param_dims)
        self.transform = ConditionalSpline(self.hypernet, self.input_dim, count_bins)
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, self.n_channels * x.shape[2])  # flatten conv features into the conditioning context
        # standard normal base distribution, one dimension per target value
        base_dist = dist.Normal(torch.zeros(self.input_dim, device=self.device),
                                torch.ones(self.input_dim, device=self.device))
        flow_dist = dist.ConditionalTransformedDistribution(base_dist, [self.transform]).condition(x)
        return flow_dist
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        flow_dist = self(x)  # renamed so it does not shadow the pyro.distributions import
        loss = -flow_dist.log_prob(y).mean()  # negative log-likelihood of the targets
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        flow_dist = self(x)
        loss = -flow_dist.log_prob(y).mean()
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        
    def test_step(self, batch, batch_idx):
        x, y = batch
        flow_dist = self(x)
        loss = -flow_dist.log_prob(y).mean()
        self.log("test_loss", loss)
        
    def predict_step(self, batch, batch_idx):
        x, y = batch
        flow_dist = self(x)
        # draw 10000 samples per example; sample_shape should be a torch.Size, not a tensor
        return flow_dist.sample(torch.Size([10000, len(x)])).T[0]  # -> (batch, n_samples)
    
    def configure_optimizers(self):
        # COCOBBackprop (COntinuous COin Betting) comes from a separate
        # implementation, imported elsewhere in my script
        optimizer = COCOBBackprop(self.parameters())
        return optimizer
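
For what it's worth, this is the quick check I tried in order to see whether the spline parameters are visible to the optimizer at all, but I am not sure it is conclusive. The seq_length of 64 and the dummy tensor shapes are placeholders, purely for illustration:

model = TCN(seq_length=64)  # placeholder sequence length, only for this check

# The ConditionalSpline has no parameters of its own; they all live in the
# DenseNN hypernetwork, so I expect it to appear under named_parameters():
for name, p in model.named_parameters():
    print(name, tuple(p.shape))

# One manual forward/backward pass on dummy data to see whether gradients
# actually reach the hypernetwork:
x = torch.randn(8, 1, 64)   # (batch, channels, seq_length)
y = torch.randn(8, 1)       # dummy targets
loss = -model(x).log_prob(y).mean()
loss.backward()
print(all(p.grad is not None for p in model.hypernet.parameters()))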

Thanks,
Sebas