SVI Validation loss

Hello,

I’m new to Pyro and am wondering how I can obtain a validation loss, abs(prediction - true_value), from SVI. I tried to use ‘prediction_mean’ from the model function, but it failed.

The code is attached for reference:

def model(self, x, y=None):
    """Bayesian regression model: Normal priors over the network parameters.

    Every parameter of ``self.network`` whose name contains ``'weight'`` or
    ``'bias'`` receives a unit-scale Normal prior.  The prior is centred on
    zero when ``self.zero_prior`` is truthy, otherwise on the parameter's
    current value.  A network is then sampled from those priors and run on
    ``x``; its output is the mean of a Normal likelihood with scale
    ``self.noise_scale``.

    Args:
        x: Input batch.  ``len(x)`` is used as the size of the data plate.
        y: Optional observations.  When given, ``y.squeeze(-1)`` is used to
            condition the ``'y'`` sample site; when ``None`` the site is
            sampled instead.

    Returns:
        ``prediction_mean``: the sampled network's output on ``x`` with the
        last dimension squeezed.
    """
    # priors for network weights and biases
    priors = {}
    for name, param in self.network.named_parameters():
        if 'weight' not in name and 'bias' not in name:
            continue
        loc = torch.zeros_like(param) if self.zero_prior else param
        # Declare every dimension of the parameter as an event dimension so
        # the whole tensor is one sample site: 2 for a matrix weight, 1 for a
        # bias vector, and still valid for parameters of any other rank.
        priors[name] = Normal(loc, torch.ones_like(param)).to_event(param.dim())

    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module('module', self.network, priors)
    # sample a regressor
    lifted_reg_model = lifted_module()

    with pyro.plate('map', len(x)):
        # run the regressor forward on data
        # NOTE(review): squeeze(-1) presumably drops a trailing singleton
        # output dimension - confirm against the network's output shape.
        prediction_mean = lifted_reg_model(x).squeeze(-1)
        # condition on the observed data
        if y is not None:
            y = y.squeeze(-1)
        pyro.sample('y', Normal(prediction_mean, self.noise_scale), obs=y)

    return prediction_mean

def guide(self, x, y=None):
    """Variational guide: independent Normal posterior per network parameter.

    For every parameter of ``self.network`` whose name contains ``'weight'``
    or ``'bias'`` this registers two Pyro parameters, ``guide_loc_<name>`` and
    ``guide_scale_<name>`` (the latter constrained positive), and uses them
    as the location and scale of a Normal distribution over that parameter.

    Args:
        x: Unused; accepted so the signature matches the model's.
        y: Unused; accepted so the signature matches the model's.

    Returns:
        A copy of the network whose parameters were sampled from the guide
        distributions.  SVI ignores this value; it is returned so callers can
        draw posterior networks for prediction.
    """
    softplus = nn.Softplus()

    # specify guide distributions for network weights and biases
    dists = {}
    for name, param in self.network.named_parameters():
        if 'weight' not in name and 'bias' not in name:
            continue
        loc = pyro.param('guide_loc_{}'.format(name), torch.randn_like(param))
        # softplus keeps the random initial scale strictly positive
        scale = pyro.param('guide_scale_{}'.format(name), softplus(torch.randn_like(param)),
                           constraint=constraints.positive)
        # All dimensions of the parameter are event dimensions: 2 for a
        # matrix weight, 1 for a bias vector, and valid for other ranks too.
        dists[name] = Normal(loc, scale).to_event(param.dim())

    # overload the parameters in the module with random samples from the guide distributions
    lifted_module = pyro.random_module('module', self.network, dists)
    # sample a regressor
    return lifted_module()

### SVI ###
def infer_svi(self, x, y, num_epochs=1000, lr=0.005, print_freq=1000, batch_size=10,
              x_val=None, y_val=None):
    """Fit the guide with stochastic variational inference on mini-batches.

    Clears the Pyro parameter store, then runs ``num_epochs`` passes over
    ``(x, y)`` in shuffled mini-batches.  The per-epoch ELBO loss divided by
    the number of training examples is appended to ``self.loss_history``.

    When both ``x_val`` and ``y_val`` are given, one network is sampled from
    the guide after every epoch, the model is replayed against that sample
    to obtain ``prediction_mean`` on ``x_val``, and the mean absolute error
    against ``y_val.squeeze(-1)`` is appended to
    ``self.validation_loss_history``.  Otherwise that list stays empty.

    Args:
        x: Training inputs.
        y: Training targets.
        num_epochs: Number of passes over the training data.
        lr: Learning rate passed to the Adam optimiser.
        print_freq: Print progress every ``print_freq`` epochs; a falsy value
            disables printing.
        batch_size: Mini-batch size for the ``DataLoader``.
        x_val: Optional held-out inputs.
        y_val: Optional held-out targets.
    """
    # create dataloader
    dataloader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)
    num_train = len(dataloader.dataset)
    track_validation = x_val is not None and y_val is not None

    # inference
    pyro.clear_param_store()
    # guide = AutoDiagonalNormal(self.model)
    guide = self.guide  # using Rui's guide function
    # NOTE(review): num_samples appears to matter only for the deprecated
    # SVI.run() interface, which is not used here - confirm before removing.
    svi = SVI(self.model, guide, optim=Adam({"lr": lr}), loss=Trace_ELBO(), num_samples=1000)
    self.loss_history = []
    self.validation_loss_history = []

    for epoch in range(num_epochs):
        total_loss = 0.
        for x_batch, y_batch in dataloader:
            total_loss += svi.step(x_batch, y_batch)

        avg_loss = total_loss / num_train
        self.loss_history.append(avg_loss)

        avg_validation_loss = None
        if track_validation:
            # No gradients are needed for evaluation.
            with torch.no_grad():
                # Record one draw of the network parameters from the guide,
                # then re-run the model with its sample sites fixed to that
                # draw; the model's return value is prediction_mean.
                guide_trace = pyro.poutine.trace(guide).get_trace(x_val)
                prediction_mean = pyro.poutine.replay(self.model, trace=guide_trace)(x_val)
                # squeeze(-1) mirrors what the model does to its targets
                abs_error = (prediction_mean - y_val.squeeze(-1)).abs()
                avg_validation_loss = abs_error.mean().item()
            self.validation_loss_history.append(avg_validation_loss)

        # `print_freq and ...` avoids a modulo-by-zero when printing is disabled
        if print_freq and epoch % print_freq == 0:
            print('epoch {0:3d} avg loss {1:.4f}'.format(epoch, avg_loss))
            if avg_validation_loss is not None:
                print('epoch {0:3d} avg validation loss {1:.4f}'.format(epoch, avg_validation_loss))