Validation split of 80%/20% not working using a model with SVI

We are trying to train the following model and print the validation loss during training. Unfortunately, we could not get this to work unless the training and test datasets have exactly the same size. Ideally we would use an 80%/20% train/test split. Is there a smart way to do this? Does anyone have a code example? :pray:t3:
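A random 80%/20% split of the rows (samples) can be sketched with torch.randperm. The helper name train_test_split and the 100 x 50 toy data are hypothetical, not part of the original post. Note that the split itself is not what breaks the model below; the answers further down explain that the size mismatch comes from the per-sample latent variables.

```python
import torch

def train_test_split(data, train_frac=0.8, seed=0):
    """Hypothetical helper: randomly split the rows (samples) of `data` into train and test sets."""
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(data.shape[0], generator=generator)  # shuffled row indices
    n_train = int(round(train_frac * data.shape[0]))
    return data[perm[:n_train]], data[perm[n_train:]]

data = torch.randn(100, 50)  # toy data: 100 samples, 20 + 30 features
train_data, test_data = train_test_split(data)
print(train_data.shape, test_data.shape)  # torch.Size([80, 50]) torch.Size([20, 50])
```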

Our Model:

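import torch
import pyro
from pyro.nn import PyroModule
from pyro.infer import SVI, Trace_ELBO, autoguide

class FA(PyroModule):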
    def __init__(self, train_data, test_data, n_features1, n_features2, K):
        """
        Args:
            Y: Tensor (Samples x Features)
            K: Number of Latent Factors
        """
        super().__init__()
        pyro.clear_param_store()
        
        # data
        self.num_features1 = n_features1
        self.num_features2 = n_features2
        self.Y1 = train_data[:,:n_features1]
        self.Y2 = train_data[:,n_features1:]
        self.K = K
        self.train_data = train_data
        self.test_data = test_data

        self.test_Y1 = test_data[:,:n_features1]
        self.test_Y2 = test_data[:,n_features1:]
        
        self.num_samples = self.Y1.shape[0]
        
        self.sample_plate = pyro.plate("sample", self.num_samples)
        self.feature_plate1 = pyro.plate("feature1", self.num_features1)
        self.feature_plate2 = pyro.plate("feature2", self.num_features2)
        self.latent_factor_plate = pyro.plate("latent factors", self.K)

        print(self.test_Y1.shape)
        print(self.test_Y2.shape)
        
    def model(self, Y1, Y2):
        """
        how to generate a matrix
        """
        with self.latent_factor_plate:
            with self.feature_plate1:
                # sample weight matrix with Normal prior distribution
                W1 = pyro.sample("W1", pyro.distributions.Normal(0., 1.))  

            with self.feature_plate2:
                # sample weight matrix with Normal prior distribution
                W2 = pyro.sample("W2", pyro.distributions.Normal(0., 1.))               
                
            with self.sample_plate:
                # sample factor matrix with Normal prior distribution
                Z = pyro.sample("Z", pyro.distributions.Normal(0., 1.))
        
        # estimate for Y
        Y1_hat = torch.matmul(Z, W1.t())
        Y2_hat = torch.matmul(Z, W2.t())
        
        with pyro.plate("feature1_", Y1.shape[1]), pyro.plate("sample_", Y1.shape[0]):
            # mask the NA (NaN) values so that they do not contribute to the likelihood
            obs_mask = torch.logical_not(torch.isnan(Y1))
            with pyro.poutine.mask(mask=obs_mask):
                # the NAs still need a valid placeholder value, even though they are masked out
                Y1 = torch.nan_to_num(Y1, nan=0.0)

                # sample a positive scale for each sample-feature pair (LogNormal prior)
                scale = pyro.sample("scale", pyro.distributions.LogNormal(0., 1.))
                # compare the reconstruction Y1_hat to the observed Y1
                pyro.sample("obs1", pyro.distributions.Normal(Y1_hat, scale), obs=Y1)


        with pyro.plate("feature2_", Y2.shape[1]), pyro.plate("sample2_", Y2.shape[0]):
            # mask the NA (NaN) values so that they do not contribute to the likelihood
            obs_mask = torch.logical_not(torch.isnan(Y2))
            with pyro.poutine.mask(mask=obs_mask):
                # the NAs still need a valid placeholder value, even though they are masked out
                Y2 = torch.nan_to_num(Y2, nan=0.0)

                # sample a positive scale for each sample-feature pair (LogNormal prior)
                scale = pyro.sample("scale2", pyro.distributions.LogNormal(0., 1.))
                # compare the reconstruction Y2_hat to the observed Y2
                pyro.sample("obs2", pyro.distributions.Normal(Y2_hat, scale), obs=Y2)

    def train(self):
        # set training parameters
        optimizer = pyro.optim.Adam({"lr": 0.02})
        elbo = Trace_ELBO()
        guide = autoguide.AutoDelta(self.model)
        
        # initialize stochastic variational inference
        svi = SVI(
            model = self.model,
            guide = guide,
            optim = optimizer,
            loss = elbo
        )
        
        num_iterations = 2000
        train_loss = []
        test_loss = []
        for j in range(num_iterations):
        #for j in enumerate(self.train_dataloader):
            # calculate the loss and take a gradient step
            loss = svi.step(self.Y1, self.Y2)

            train_loss.append(loss/self.Y1.shape[0])
            #    test_loss.append(elbo.loss(self.model, guide, test_data))
            if j % 200 == 0:
                print("[iteration %04d] loss: %.4f" % (j + 1, loss / self.Y1.shape[0]))
            
            with torch.no_grad():  # for logging only
                train_loss2 = elbo.loss(self.model, guide, self.Y1, self.Y2) # or average over batch_loss
                test_loss = elbo.loss(self.model, guide, self.test_Y1, self.test_Y2)
            print(train_loss2, test_loss)
        
        # Obtain maximum a posteriori estimates for W and Z
        map_estimates = guide(self.Y1, self.Y2)
        
        return train_loss, map_estimates

Error we get:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[70], line 2
      1 FA_model = FA(train_data, test_data,20, 30, 5)
----> 2 losses, estimates = FA_model.train()

Cell In[69], line 114
    112     with torch.no_grad():  # for logging only
    113         train_loss2 = elbo.loss(self.model, guide, self.Y1, self.Y2) # or average over batch_loss
--> 114         test_loss = elbo.loss(self.model, guide, self.test_Y1, self.test_Y2)
    115     print(train_loss2, test_loss)
    117 # Obtain maximum a posteriori estimates for W and Z

File ~/.local/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:72, in Trace_ELBO.loss(self, model, guide, *args, **kwargs)
     65 """
     66 :returns: returns an estimate of the ELBO
     67 :rtype: float
     68 
     69 Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
     70 """
     71 elbo = 0.0
---> 72 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
     73     elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(
     74         guide_trace.log_prob_sum()
     75     )
     76     elbo += elbo_particle / self.num_particles

File ~/.local/lib/python3.8/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs)
    235 else:
    236     for i in range(self.num_particles):
--> 237         yield self._get_trace(model, guide, args, kwargs)

File ~/.local/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File ~/.local/lib/python3.8/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     72 guide_trace = prune_subsample_sites(guide_trace)
     73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
     76 guide_trace.compute_score_parts()
     77 if is_validation_enabled():

File ~/.local/lib/python3.8/site-packages/pyro/poutine/trace_struct.py:276, in Trace.compute_log_prob(self, site_filter)
    270     raise ValueError(
    271         "Error while computing log_prob at site '{}':\n{}\n{}".format(
    272             name, exc_value, shapes
    273         )
    274     ).with_traceback(traceback) from e
    275 site["unscaled_log_prob"] = log_p
--> 276 log_p = scale_and_mask(log_p, site["scale"], site["mask"])
    277 site["log_prob"] = log_p
    278 site["log_prob_sum"] = log_p.sum()

File ~/.local/lib/python3.8/site-packages/pyro/distributions/util.py:328, in scale_and_mask(tensor, scale, mask)
    326 if mask is False:
    327     return torch.zeros_like(tensor)
--> 328 return torch.where(mask, tensor * scale, tensor.new_zeros(()))

RuntimeError: The size of tensor a (80) must match the size of tensor b (20) at non-singleton dimension 0

Thank you, any help is greatly appreciated. :pray:t3::sparkles:

You have local latent variables Z that are “private” to each data point. In your model Z is sampled inside sample_plate, whose size is fixed to the number of training samples (80), so the guide learns exactly one row of Z per training sample. At test/validation time you therefore also need to do inference, since you are faced with a new inference task: inferring the posterior over Z for the newly introduced data points. Unless you use so-called amortized inference, there is no cheap shortcut around this new inference task, and in the context of variational inference a new inference task means a new optimization problem.


Thank you for your help. :smiling_face: :pray:

Are you suggesting that we do all of the training first and then all of the validation afterwards, with a separate model? Just checking that I understand you correctly.

Is there no way of getting a per-iteration validation loss, as described in the following forum thread: https://forum.pyro.ai/t/how-to-get-validation-loss-in-svi/4037

I tried to learn more about amortized inference. Is there a beginner-friendly tutorial on amortized inference in Pyro that you could recommend?

The concept of amortization is briefly described here: SVI Part II: Conditional Independence, Subsampling, and Amortization (Pyro Tutorials 1.9.0 documentation)

One concrete example can be found in the VAE tutorial: Variational Autoencoders (Pyro Tutorials 1.9.0 documentation)

Without amortization it will probably be computationally intractable to get a useful “validation loss”.


Thank you again for the fast and incredibly helpful answers. :smiling_face: :pray:

May I also ask how you evaluate/compare your statistical models if you do not use validation?

Sorry for the number of questions. Moving from classical machine learning to probabilistic programming takes a little time.

In your case I would do validation once, at the end. To do that, I'd learn new approximate posteriors for the newly introduced locals (e.g. Z for the test samples) while freezing the approximate posteriors for the globals to the fits obtained from the training data.


Sorry for so many questions. What counts as the global variables in our case? :sweat_smile:

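The opposite of local is global: latent variables that are shared across all data points. In your model those are W1 and W2, whereas Z (and the per-entry scale and scale2) are local, since they have one entry per sample.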