Problem saving a Pyro VAE with torch.save: weakref error

I have this fairly straightforward VAE that I want to train and save:

class Encoder(nn.Module):
    """MLP that maps an input vector to the parameters of q(z|x).

    The last linear layer emits ``2 * latent_dim`` features. They are split in
    half along the last axis: the first half is returned as ``z_mu`` and the
    second half as ``z_logvar``.

    NOTE(review): VAE passes ``exp()`` of the second half directly as a Normal
    *scale*, so in practice it behaves as a log standard deviation rather than
    a log variance.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        to_hidden = nn.Linear(input_dim, hidden_dim)
        to_stats = nn.Linear(hidden_dim, latent_dim * 2)
        # Kept under the attribute name ``encoder`` so state_dict keys stay
        # "encoder.0.*" / "encoder.2.*".
        self.encoder = nn.Sequential(to_hidden, nn.ReLU(), to_stats).to(device)

    def forward(self, x):
        """Return ``(z_mu, z_logvar)``, each with ``latent_dim`` trailing features."""
        stats = self.encoder(x)
        z_mu, z_logvar = stats.chunk(2, dim=-1)
        return z_mu, z_logvar

class Decoder(nn.Module):
    """MLP that maps a latent code z to the parameters of p(x|z).

    The last linear layer emits ``2 * output_dim`` features. They are split in
    half along the last axis into ``(x_mu, x_logvar)``.

    NOTE(review): VAE passes ``exp(x_logvar)`` directly as a Normal *scale*, so
    the second half effectively acts as a log standard deviation.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        # Device placement is done by ``.to(device)`` below. A ``self.cuda()``
        # issued before any submodule exists would move nothing.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim * 2),  # mean half, then log-scale half
        ).to(device)

    def forward(self, z):
        """Return ``(x_mu, x_logvar)``, each with ``output_dim`` trailing features."""
        x = self.decoder(z)
        x_mu, x_logvar = torch.chunk(x, 2, dim=-1)
        return x_mu, x_logvar

class XYEncoder(nn.Module):
    """MLP encoder over the concatenation of two inputs ``x`` and ``y``.

    ``input_dim`` must equal the sum of the trailing sizes of ``x`` and ``y``.
    The last linear layer emits ``2 * latent_dim`` features. They are split in
    half along the last axis into ``(z_mu, z_logvar)``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(XYEncoder, self).__init__()
        # Device placement is done by ``.to(device)`` below. A ``self.cuda()``
        # issued before any submodule exists would move nothing.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),  # mean half, then log-scale half
        ).to(device)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on their feature axis and encode them.

        Concatenating on the last axis matches the ``chunk(..., dim=-1)`` split.
        For 2-D ``(batch, features)`` inputs this is the same as ``dim=1``, and it
        also accepts unbatched 1-D inputs.
        """
        xy = torch.cat((x, y), dim=-1)
        h = self.encoder(xy)
        z_mu, z_logvar = torch.chunk(h, 2, dim=-1)
        return z_mu, z_logvar


class VAE(nn.Module):
    """Beta-VAE with MLP encoder/decoder, trained through Pyro (``model``/``guide``).

    Args:
        input_dim: size of the observed feature vector.
        z_dim: size of the latent code.
        hidden_dim: width of the hidden layer in encoder and decoder.
        beta: weight applied to the latent site in both model and guide.
        device: device the parameters are moved to. Training code moves each
            batch to ``self.device``, so the parameters must live there too.

    NOTE(review): ``model``/``guide`` register the submodules with
    ``pyro.module``. Pickling the whole module with ``torch.save(vae)`` was
    reported to fail with a weakref error; save ``vae.state_dict()`` instead.
    """

    def __init__(self, input_dim, z_dim=50, hidden_dim=400, beta=1.0, device="cuda"):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(input_dim, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, input_dim)
        # Move parameters only after the submodules exist. Calling cuda()
        # before they were created moved nothing.
        self.to(device)

        self.device = device
        self.z_dim = z_dim
        self.beta = beta

    # define the model p(x|z)p(z)
    def model(self, x):
        """Generative model: z ~ N(0, I), x | z ~ Normal(decoder(z))."""
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        batch_size = x.shape[0]
        with pyro.plate("data", batch_size):
            # standard-Normal prior p(z): loc 0, scale 1
            z_loc = x.new_zeros(torch.Size((batch_size, self.z_dim)))
            z_scale = x.new_ones(torch.Size((batch_size, self.z_dim)))
            with pyro.poutine.scale(scale=self.beta):
                # z_scale is already a standard deviation. Exponentiating it
                # would give a prior std of e instead of 1.
                z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            x_mu, x_logvar = self.decoder(z)
            # score against the actual data; exp() of the second decoder half is the scale
            pyro.sample("obs", dist.Normal(x_mu, torch.exp(x_logvar)).to_event(1), obs=x)

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        """Variational posterior q(z|x) = Normal(encoder(x))."""
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        batch_size = x.shape[0]
        with pyro.plate("data", batch_size):
            # the encoder's second half is used as a log scale (exp() below)
            z_loc, z_logscale = self.encoder(x)
            # Scale the guide's latent site by the same beta as the model's.
            # Then the latent term of the ELBO is beta * (log p(z) - log q(z|x)),
            # not just a reweighted prior.
            with pyro.poutine.scale(scale=self.beta):
                pyro.sample("latent", dist.Normal(z_loc, torch.exp(z_logscale)).to_event(1))

    def predict(self, maldi_data, density, sample="mean"):
        """Encode ``maldi_data``, draw one z from q(z|x) and decode it.

        Args:
            maldi_data: input batch passed to the encoder.
            density: unused here; kept so existing callers keep working.
            sample: ``"mean"`` returns the decoder mean for the sampled z.
                Any other value draws x from the decoder's Normal.
        """
        z_loc, z_logscale = self.encoder(maldi_data)
        z = dist.Normal(z_loc, torch.exp(z_logscale)).sample()
        x_mu, x_logvar = self.decoder(z)
        if sample == "mean":
            return x_mu
        return dist.Normal(x_mu, torch.exp(x_logvar)).sample()

When I call torch.save, I get an error about a weakref. Any idea where the weakref comes from?

# Train the VAE with Pyro SVI, then save its parameters.
# SVI steps its own Pyro Adam wrapper. A separate torch.optim.Adam over
# vae.parameters() would be overwritten below and never stepped, so none is built.
# train the model
logging.info("Training the VAE model")
print(model_name)
optimizer = Adam({"lr": 0.001})
pyro.get_param_store().clear()
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):
    total_loss = 0.0
    # Only the data tensor feeds the ELBO; the density target is not used here.
    for data, _density in data_loader:
        x = data.to(vae.device)
        loss = svi.step(x)
        wandb.log({"loss": loss})
        total_loss += loss
    print("Epoch ", epoch, " Loss ", total_loss)
logging.info("Saving the model")
# Pickling the whole module (torch.save(vae)) is what raised the weakref
# error. state_dict() holds detached tensors, which serialize cleanly.
# Reload with: vae = VAE(...); vae.load_state_dict(torch.load(str(model_file)))
torch.save(vae.state_dict(), str(model_file))

Any help would be highly appreciated. Thank you and best regards.

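Try saving the parameters instead of pickling the whole module. `pyro.module(...)` registers the module's parameters in Pyro's param store, which attaches weak references to them, and weakrefs cannot be pickled by `torch.save(vae)`. Use `pyro.get_param_store().save(path)` (restore with `pyro.get_param_store().load(path)`), or equivalently `torch.save(pyro.get_param_store().get_state(), path)`. Alternatively, save `vae.state_dict()` and restore it later with `vae.load_state_dict(torch.load(path))`.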