How to get Validation Loss in SVI

Hello,

My understanding is that training a model with SVI is very similar to training a neural network. One of the most useful outputs when training a neural net is the pair of training vs. validation learning curves. These come from evaluating the model at each epoch on both the training (in-sample) and validation (out-of-sample) data to get the training loss and validation loss, and then producing a plot similar to the one below:

I would like to get the same two curves for my SVI training. The training loss is straightforward, since it is already available as svi_results.losses.

For the validation loss, though, I'm struggling a bit. The svi.run() method only accepts the training data, so I'm guessing the only way to get this loss is to write a custom training loop, probably using the Predictive class to evaluate the loss at every epoch. What is the most efficient way of doing this? Is there an existing implementation that I missed?

I couldn't find any resources on this particular problem, so I'd also appreciate pointers to any relevant material.

Thank you!

I generally use the SVI.step() method directly rather than SVI.run(). With an explicit loop, you can compute the loss separately on two different subsets of data:

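```python
import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

train_data = ...
test_data = ...
elbo = Trace_ELBO()
svi = SVI(model, guide, Adam({"lr": 0.01}), elbo)  # model and guide take the data as their argument
train_losses, test_losses = [], []  # collected for plotting the learning curves
for epoch in range(num_epochs):
    for batch in my_minibatcher(train_data):  # my_minibatcher: your own batching helper
        batch_loss = svi.step(batch)  # one gradient update on this minibatch
    with torch.no_grad():  # for logging only
        train_loss = elbo.loss(model, guide, train_data)  # or average over batch_loss
        test_loss = elbo.loss(model, guide, test_data)
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    print(train_loss, test_loss)
```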
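To see the whole pattern end to end without installing a framework, here is a minimal, framework-free sketch of the same idea in plain Python. It is not Pyro's or NumPyro's API: the toy model (x_i ~ Normal(mu, 1) with a Normal(0, 10²) prior on mu), the closed-form ELBO, and the helper names neg_elbo, grad_neg_elbo, and record are illustrative assumptions. Training takes minibatch gradient steps on the ELBO, rescaling the minibatch likelihood by N / batch size (the same rescaling that subsampling in a Pyro plate applies). After each epoch, the same loss is evaluated, with no parameter update, on the full training set and on the held-out set.

```python
import math
import random

PRIOR_SD = 10.0  # assumed prior: mu ~ Normal(0, PRIOR_SD**2)

def neg_elbo(data, m, log_s, scale=1.0):
    """Closed-form negative ELBO for the toy model x_i ~ Normal(mu, 1), with
    guide q(mu) = Normal(m, s**2). `scale` reweights the likelihood term."""
    s2 = math.exp(2 * log_s)
    # E_q[log Normal(x | mu, 1)] = -0.5*log(2*pi) - 0.5*((x - m)**2 + s**2)
    exp_loglik = sum(-0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2) for x in data)
    # KL(Normal(m, s**2) || Normal(0, PRIOR_SD**2))
    kl = math.log(PRIOR_SD) - log_s + (s2 + m ** 2) / (2 * PRIOR_SD ** 2) - 0.5
    return -scale * exp_loglik + kl

def grad_neg_elbo(data, m, log_s, scale=1.0):
    """Gradient of neg_elbo with respect to (m, log_s)."""
    s2 = math.exp(2 * log_s)
    g_m = -scale * sum(x - m for x in data) + m / PRIOR_SD ** 2
    g_log_s = scale * len(data) * s2 - 1.0 + s2 / PRIOR_SD ** 2
    return g_m, g_log_s

random.seed(0)
train_data = [random.gauss(3.0, 1.0) for _ in range(800)]  # synthetic data
val_data = [random.gauss(3.0, 1.0) for _ in range(200)]

m, log_s = 0.0, 0.0  # variational parameters of q(mu)
lr, batch_size, num_epochs = 1e-3, 50, 20
train_curve, val_curve = [], []

def record(m, log_s):
    # Evaluation only (no update). Divide by the number of points so the two
    # curves are on a comparable per-datapoint scale.
    train_curve.append(neg_elbo(train_data, m, log_s) / len(train_data))
    val_curve.append(neg_elbo(val_data, m, log_s) / len(val_data))

record(m, log_s)  # losses at initialization
for epoch in range(num_epochs):
    random.shuffle(train_data)
    for i in range(0, len(train_data), batch_size):
        batch = train_data[i:i + batch_size]
        scale = len(train_data) / len(batch)  # rescale minibatch likelihood to the full dataset
        g_m, g_log_s = grad_neg_elbo(batch, m, log_s, scale)
        m -= lr * g_m  # plain gradient-descent step
        log_s -= lr * g_log_s
    record(m, log_s)  # one point per epoch on each curve

print(f"final train {train_curve[-1]:.3f}, val {val_curve[-1]:.3f}")
```

Because the ELBO sums over data points, both curves are divided by the number of points; otherwise the 800-point training loss and the 200-point validation loss would differ mostly because of dataset size. The KL term is still counted once per evaluation, so a small per-point offset between the curves can remain even at identical parameters.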

Is there any way I could see the code you ended up with, @crisgompec? I have run into a similar problem recently.