Validation split of 80%/20% not working using a model with SVI

We are trying to train the following model and print the validation loss during training. Unfortunately, we could only get this to work when the training and test datasets have exactly the same size. Ideally we would like an 80%/20% train/test split. Is there a smart way to do this? Does anyone have a code example?

Our Model:

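import torch
import pyro
from pyro.infer import SVI, Trace_ELBO, autoguide
from pyro.nn import PyroModule

class FA(PyroModule):
    def __init__(self, train_data, test_data, n_features1, n_features2, K):
        """
        train_data, test_data: Tensor (Samples x Features)
        K: Number of Latent Factors
        """
        super().__init__()
        # data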
        self.num_features1 = n_features1
        self.num_features2 = n_features2
        self.Y1 = train_data[:,:n_features1]
        self.Y2 = train_data[:,n_features1:]
        self.K = K
        self.train_data = train_data
        self.test_data = test_data

        self.test_Y1 = test_data[:,:n_features1]
        self.test_Y2 = test_data[:,n_features1:]
        self.num_samples = self.Y1.shape[0]
        self.sample_plate = pyro.plate("sample", self.num_samples)
        self.feature_plate1 = pyro.plate("feature1", self.num_features1)
        self.feature_plate2 = pyro.plate("feature2", self.num_features2)
        self.latent_factor_plate = pyro.plate("latent factors", self.K)

    def model(self, Y1, Y2):
        """Generative model: how the data matrices Y1 and Y2 are generated."""
        with self.latent_factor_plate:
            with self.feature_plate1:
                # sample weight matrix with Normal prior distribution
                W1 = pyro.sample("W1", pyro.distributions.Normal(0., 1.))  

            with self.feature_plate2:
                # sample weight matrix with Normal prior distribution
                W2 = pyro.sample("W2", pyro.distributions.Normal(0., 1.))               
            with self.sample_plate:
                # sample factor matrix with Normal prior distribution
                Z = pyro.sample("Z", pyro.distributions.Normal(0., 1.))
        # estimate for Y
        Y1_hat = torch.matmul(Z, W1.t())
        Y2_hat = torch.matmul(Z, W2.t())
        with pyro.plate("feature1_", Y1.shape[1]), pyro.plate("sample_", Y1.shape[0]):
            # masking the NA values such that they are not considered in the distributions
            obs_mask = torch.ones_like(Y1, dtype=torch.bool)
            if Y1 is not None:
                obs_mask = torch.logical_not(torch.isnan(Y1))
            with pyro.poutine.mask(mask=obs_mask):
                if Y1 is not None:
                    # a valid value for the NAs has to be defined even though these samples will be ignored later
                    Y1 = torch.nan_to_num(Y1, nan=0) 
                    # sample scale parameter for each feature-sample pair with LogNormal prior (has to be positive)
                    scale = pyro.sample("scale", pyro.distributions.LogNormal(0., 1.))
                    # compare sampled estimation to the true observation Y
                    pyro.sample("obs1", pyro.distributions.Normal(Y1_hat, scale), obs=Y1)

        with pyro.plate("feature2_", Y2.shape[1]), pyro.plate("sample2_", Y2.shape[0]):
            # masking the NA values such that they are not considered in the distributions
            obs_mask = torch.ones_like(Y2, dtype=torch.bool)
            if Y2 is not None:
                obs_mask = torch.logical_not(torch.isnan(Y2))
            with pyro.poutine.mask(mask=obs_mask):
                if Y2 is not None:
                    # a valid value for the NAs has to be defined even though these samples will be ignored later
                    Y2 = torch.nan_to_num(Y2, nan=0) 
                    # sample scale parameter for each feature-sample pair with LogNormal prior (has to be positive)
                    scale = pyro.sample("scale2", pyro.distributions.LogNormal(0., 1.))
                    # compare sampled estimation to the true observation Y
                    pyro.sample("obs2", pyro.distributions.Normal(Y2_hat, scale), obs=Y2)

    def train(self):
        # set training parameters
        optimizer = pyro.optim.Adam({"lr": 0.02})
        elbo = Trace_ELBO()
        guide = autoguide.AutoDelta(self.model)
        # initialize stochastic variational inference
        svi = SVI(
            model=self.model,
            guide=guide,
            optim=optimizer,
            loss=elbo,
        )
        num_iterations = 2000
        train_loss = []
        test_loss = []
        for j in range(num_iterations):
            # calculate the loss and take a gradient step
            # (Y1 and Y2 are both Samples x Features, so neither is transposed)
            loss = svi.step(self.Y1, self.Y2)
            train_loss.append(loss)
            if j % 200 == 0:
                print("[iteration %04d] loss: %.4f" % (j + 1, loss / self.Y1.shape[0]))
            with torch.no_grad():  # for logging only
                train_loss2 = elbo.loss(self.model, guide, self.Y1, self.Y2)  # or average over batch_loss
                # this is the call that fails when the test set has a different number of samples
                test_loss.append(elbo.loss(self.model, guide, self.test_Y1, self.test_Y2))
            print(train_loss2, test_loss[-1])
        # Obtain maximum a posteriori estimates for W and Z
        map_estimates = guide(self.Y1, self.Y2)
        return train_loss, map_estimates

Error we get:

Thank you, any help is greatly appreciated.

you have local latent variables z that are “private” to each data point. so at test/validation time you also need to do inference as you’re faced with a new inference task (inferring the posterior over z for newly introduced data points). unless you use so-called amortized inference there is no cheap shortcut to avoiding the new inference task, and in the context of variational inference new inference task means new optimization task.

Thank you for your help.

Are you proposing to do all of the training first and then all of the validation with a separate model afterwards? I am just checking that I understand you correctly.

Is there no way of getting per-iteration training data, as described in the following blog post:

I tried learning more about amortized inference. Do you have a beginner-friendly tutorial that is out there for amortized inference using pyro?

the concept of amortization is briefly described here: SVI Part II: Conditional Independence, Subsampling, and Amortization — Pyro Tutorials 1.9.0 documentation

one concrete example can be found in the VAE: Variational Autoencoders — Pyro Tutorials 1.9.0 documentation

without amortization it’ll probably be computationally intractable to get a useful “validation loss” at every iteration

Thank you for the fast and incredible help.

May I also ask how you evaluate/compare your statistical models if you do not use validation?

Sorry for the amount of questions. Moving from classical Machine Learning to Probabilistic Programming requires a little time.

in your case I would do validation once at the end. to do that I’d learn new approximate posteriors for the newly introduced locals while freezing the approximate posteriors for the globals to the fits obtained from the training data

Sorry for so many questions. What counts as the global posteriors in our case?

the opposite of local is global