Computing a validation loss with SVI

Hello,

I’m new to Pyro and am wondering how I can obtain a validation loss, e.g. abs(prediction - true_value), from SVI. I tried to use `prediction_mean` from the model function, but that didn't work.

The code is attached for reference:

    def model(self, x, y=None):
        """Bayesian regression model over the weights and biases of ``self.network``.

        Every network parameter whose name contains 'weight' or 'bias' receives an
        independent Normal prior with unit scale. Its location is zero when
        ``self.zero_prior`` is truthy; otherwise it is the network's current value
        for that parameter. One network is sampled from these priors. It is run on
        ``x`` inside a plate, and targets are modelled as
        ``Normal(prediction_mean, self.noise_scale)``.

        Args:
            x: inputs passed to ``self.network``; ``len(x)`` sets the plate size.
            y: optional observed targets. When given, they are squeezed on the last
                dimension and condition the ``'y'`` site. When ``None``, the ``'y'``
                site is sampled instead (useful for posterior predictive draws).

        Returns:
            The sampled network's output with its last dimension squeezed. Because
            it is the model's return value, ``pyro.infer.Predictive`` can collect
            it through the special ``"_RETURN"`` site.
        """
        priors = {}
        for name, param in self.network.named_parameters():
            # Only weights and biases get a prior; other parameters are not lifted here.
            if 'weight' not in name and 'bias' not in name:
                continue
            loc = torch.zeros_like(param) if self.zero_prior else param
            # Reinterpret *all* tensor dims as event dims so each parameter contributes a
            # single scalar log-prob. This covers 2-D Linear weights and 1-D biases as
            # before, and also 1-D norm-layer weights or 4-D conv kernels.
            priors[name] = Normal(loc, torch.ones_like(param)).to_event(param.dim())

        # Lift module parameters to random variables drawn from the priors, then sample
        # one concrete regressor.
        lifted_module = pyro.random_module('module', self.network, priors)
        lifted_reg_model = lifted_module()

        # NOTE(review): the plate is sized to the current (mini)batch, so with
        # mini-batch SVI the likelihood is not rescaled to the full dataset. The
        # prior/KL term is therefore weighted more heavily than in full-batch
        # training. If the full dataset size N is known, consider
        # pyro.plate('map', size=N, subsample_size=len(x)).
        with pyro.plate('map', len(x)):
            # Run the sampled regressor forward on the data.
            prediction_mean = lifted_reg_model(x).squeeze(-1)
            # Match the target's shape to the squeezed prediction before conditioning.
            if y is not None:
                y = y.squeeze(-1)
            pyro.sample('y', Normal(prediction_mean, self.noise_scale), obs=y)

        return prediction_mean
    def guide(self, x, y=None):
        """Mean-field Normal variational posterior over the network's weights and biases.

        For every parameter whose name contains 'weight' or 'bias', the guide
        registers learnable ``guide_loc_<name>`` and ``guide_scale_<name>`` tensors
        in the Pyro param store. ``pyro.param`` only uses the random initial value
        on first registration. The guide then samples a network under the same
        'module' name used by ``model``, so the sample sites line up.

        Args:
            x, y: unused here; accepted so the guide has the same signature as ``model``.
        """
        softplus = nn.Softplus()

        dists = {}
        for name, param in self.network.named_parameters():
            if 'weight' not in name and 'bias' not in name:
                continue
            loc = pyro.param('guide_loc_{}'.format(name), torch.randn_like(param))
            # softplus makes the initial scale positive; the constraint keeps it
            # positive during optimisation.
            scale = pyro.param('guide_scale_{}'.format(name), softplus(torch.randn_like(param)),
                               constraint=constraints.positive)
            # Same event reinterpretation as the prior, so guide and model log-probs
            # have matching shapes for any parameter rank.
            dists[name] = Normal(loc, scale).to_event(param.dim())

        # Overload the module's parameters with samples from the guide distributions
        # and draw one regressor (this call creates the guide's sample sites).
        lifted_module = pyro.random_module('module', self.network, dists)
        lifted_module()
    ### SVI ###
    def infer_svi(self, x, y, num_epochs=1000, lr=0.005, print_freq=1000, batch_size=10,
                  x_val=None, y_val=None, num_val_samples=100):
        """Fit the variational posterior with mini-batch SVI and record per-epoch losses.

        Args:
            x, y: training inputs and targets, wrapped in a shuffled ``DataLoader``.
            num_epochs: number of passes over the training data.
            lr: Adam learning rate.
            print_freq: print progress every ``print_freq`` epochs.
            batch_size: mini-batch size.
            x_val, y_val: optional held-out data, to be given together. When both are
                supplied, a validation loss is computed after every epoch.
            num_val_samples: number of posterior draws averaged to form the
                predictive mean used for the validation loss.

        Side effects:
            Clears the Pyro param store. Resets and fills ``self.loss_history`` with
            the ELBO loss summed over each epoch's batches and divided by the dataset
            size. Resets and fills ``self.validation_loss_history`` with the mean
            absolute error between the posterior predictive mean and ``y_val``; this
            list stays empty when no validation data is given.

        Raises:
            ValueError: if only one of ``x_val`` / ``y_val`` is given.
        """
        if (x_val is None) != (y_val is None):
            raise ValueError("x_val and y_val must be given together")
        do_validation = x_val is not None

        dataloader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)

        pyro.clear_param_store()
        guide = self.guide
        # SVI's num_samples only matters for the deprecated .run(); .step() ignores it.
        svi = SVI(self.model, guide, optim=Adam({"lr": lr}), loss=Trace_ELBO())
        self.loss_history = []
        self.validation_loss_history = []

        predictive = None
        if do_validation:
            # Local import: only needed when validation data is supplied.
            from pyro.infer import Predictive
            # Predictive runs the guide, then replays the model on those samples. The
            # "_RETURN" site holds the model's return value (prediction_mean). The guide
            # reads the param store at call time, so one instance tracks training.
            predictive = Predictive(self.model, guide=guide, num_samples=num_val_samples,
                                    return_sites=("_RETURN",))

        for epoch in range(num_epochs):
            total_loss = 0.
            for x_batch, y_batch in dataloader:
                total_loss += svi.step(x_batch, y_batch)

            avg_loss = total_loss / len(dataloader.dataset)
            self.loss_history.append(avg_loss)

            avg_validation_loss = None
            if predictive is not None:
                with torch.no_grad():
                    # Average over posterior draws (dim 0) to get the predictive mean.
                    pred_mean = predictive(x_val)["_RETURN"].mean(0)
                    target = y_val.reshape(pred_mean.shape)
                    avg_validation_loss = (pred_mean - target).abs().mean().item()
                self.validation_loss_history.append(avg_validation_loss)

            if epoch % print_freq == 0:
                msg = 'epoch {0:3d} avg loss {1:.4f}'.format(epoch, avg_loss)
                if avg_validation_loss is not None:
                    msg += ' val MAE {0:.4f}'.format(avg_validation_loss)
                print(msg)

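I think you’ll want to use `pyro.infer.Predictive` to make new predictions (for example, by collecting the model's return value through the `"_RETURN"` site in `return_sites`), and then compute the L1 or L2 loss against the true values by hand.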