Custom SVI loop: loss does not decrease

Hi all,

I am working through the tutorial on the low-level SVI loop (the "Customizing SVI objectives and training loops" page of the Pyro Tutorials 1.8.4 documentation) to better understand the svi.step() function and to integrate Pyro with PyTorch Lightning, which needs a torch.optim optimizer rather than a pyro.optim one.

The loop runs, but the loss does not decrease (it stays flat), which is most likely due to my misunderstanding of the svi.step() mechanics.

So far my loop looks like this:

import torch
import pyro
from pyro import poutine

loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

for epoch in range(10):
        for i, data in enumerate(train_loader, 0):

            x,y = data

            # Trace the param sites (param_only=True) visited by the model and guide
            with poutine.trace(param_only=True) as param_capture:
                loss = loss_fn(mnistmodel, guide, x, y)
            loss.backward()

            # Collect the unconstrained parameter tensors from the traced param sites
            params = set(site["value"].unconstrained() for site in param_capture.trace.nodes.values())

            # Perform gradient step and empty the gradients
            optimizer = torch.optim.Adam(params, lr=0.001)
            optimizer.step()
            optimizer.zero_grad()

        print(loss / x.shape[0])

My understanding is that the loop:

  • gets a batch of training data from the DataLoader
  • calculates the loss while collecting every parameter each time it is used (with the trace feature)
  • calculates the gradients of the loss with respect to the learnable weights (backward pass)
  • tells the optimizer to perform one learning step
  • zeros the optimizer’s gradients

Thank you very much in advance for your help!

If needed, I can post more of the code, but when I use the high-level svi.step() the model trains normally:

svi = SVI(mnistmodel, guide, adam, loss=Trace_ELBO())

for epoch in range(10):
        for i, data in enumerate(train_loader, 0):
            x,y = data
            loss = svi.step(x, y)
        print(loss / x.shape[0])

Please refer to this

This is exactly what I did at the beginning. I was trying to get the parameters using:

params = []

for m in mnistmodel.modules():
        for name, value in list(m.named_parameters(recurse=False)):
            params.append(value)

And then I followed the tutorial:

optimizer = torch.optim.Adam(params, lr = 0.001)

for epoch in range(10):
        for i, data in enumerate(train_loader, 0):
            x,y = data
            loss = loss_fn(mnistmodel, guide, x, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(loss / x.shape[0])

However, the model did not train either, and as far as I understand, the parameters need to be collected with the trace. But this may be my own misunderstanding!

You probably need to run the model once to instantiate the parameters. The parameters are created on the fly and do not exist before the model is run for the first time.

Do you mean having two loops, like this:

params = []

for m in mnistmodel.modules():
        for name, value in list(m.named_parameters(recurse=False)):
            params.append(value)

optimizer = torch.optim.Adam(params, lr = 0.001)

# Instantiate the parameters
for epoch in range(1):
        for i, data in enumerate(train_loader, 0):
            x,y = data
            loss = loss_fn(mnistmodel, guide, x, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(loss / x.shape[0])

# Train the model
for epoch in range(10):
        for i, data in enumerate(train_loader, 0):
            x,y = data
            loss = loss_fn(mnistmodel, guide, x, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(loss / x.shape[0])

If so, I have tried it, but without much success.

I’ve also been looking at minipyro, and its code is very similar to the first loop I posted. There might be a subtle difference I am missing that is breaking the training.

Can you please paste a complete, runnable script with import statements etc. and the world’s simplest model? Otherwise it’s hard to know where you’re going wrong.

Sorry, here is the whole code. Note that I use my own custom DataLoader for the MNIST dataset, since torchvision is somehow not supported on my system. You can download the MNIST dataset with torchvision and use my loader; it should work.

  • Code for handling the MNIST dataset and code for the very simple model
import glob
import os
import numpy as np
import pytorch_lightning as pl
import albumentations as A

import torch
import torch.nn as nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO
from pyro import poutine

from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image

from albumentations.pytorch import ToTensorV2

class MNISTLoader(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.samples = []
        self._init_dataset()

    def __getLabels__(self):
        # Sorted unique labels found in the dataset
        return sorted({label for _, label in self.samples})

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):

        img_path, label = self.samples[idx]
        image = Image.open(img_path)
        image = np.array(image)

        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return (image, int(label))

    def _init_dataset(self):

        files = glob.glob(self.img_dir + "/*")
        labels = []

        for file in files:
            label = file.split(".")[0].split("num")[-1]
            labels.append(label)
            self.samples.append((file, label))

class MNISTDataModule(pl.LightningDataModule):

    def train_val_dataloader(self):
        transform=A.Compose([A.Normalize(mean = (0.1307), std = (0.3081)),
                            ToTensorV2()])
        mnist_train = MNISTLoader("/Data/mnist/mnist_train/train", transform=transform)
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        mnist_train = DataLoader(mnist_train, batch_size=64)
        mnist_val = DataLoader(mnist_val, batch_size=64)
        return mnist_train, mnist_val

    def test_dataloader(self):
        transform=A.Compose([A.Normalize(mean = (0.1307), std = (0.3081)),
                            ToTensorV2()])
        mnist_test = MNISTLoader("/Data/mnist/mnist_test/test",  transform=transform)
        mnist_test = DataLoader(mnist_test, batch_size=64)
        return mnist_test


class MNISTModel(nn.Module):

    def __init__(self):
        super(MNISTModel, self).__init__()

        # MNIST images are (1,28,28)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x, y=None):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3
        x = self.layer_3(x)
        
        # Proba over labels
        x = torch.log_softmax(x, dim=1)

        return x
  • Code for adding sample sites:
mnistmodel = MNISTModel()
pyro.nn.module.to_pyro_module_(mnistmodel)

for m in mnistmodel.modules():
    for name, value in list(m.named_parameters(recurse=False)):
        print(f"Adding sample site {m._pyro_name} {name}")
        setattr(m, name, PyroSample(prior=dist.Normal(0., 1.).expand(value.shape).to_event(value.dim())))
  • Instantiate the dataset, guide and the loss function and run the manual SVI
train_loader, val_loader = MNISTDataModule().train_val_dataloader()
guide = AutoDiagonalNormal(mnistmodel)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

for epoch in range(10):
        for i, data in enumerate(train_loader, 0):

            x,y = data

            with poutine.trace(param_only=True) as param_capture:
                loss = loss_fn(mnistmodel, guide, x, y)
            loss.backward()

            params = set(site["value"].unconstrained() for site in param_capture.trace.nodes.values())

            # Perform gradient step and empty the gradients
            optimizer = torch.optim.Adam(params, lr=0.001)
            optimizer.step()
            optimizer.zero_grad()

        print(loss / x.shape[0])

Thank you for your help!

As far as I can tell, your model doesn’t have any parameters (what it has are sample statements). Only your guide has parameters. You need to capture those parameters.

It also looks like you’re working in the universe of Bayesian neural networks. I’d generally recommend using TyXe for that instead of pure Pyro. As a forewarning, it’s very easy to get bad results with Bayesian neural networks unless you’re pretty familiar with the technical details.

If, instead of the manual SVI loop, I run

svi = SVI(mnistmodel, guide, adam, loss=Trace_ELBO())

for epoch in range(10):
        for i, data in enumerate(train_loader, 0):
            x,y = data
            loss = svi.step(x, y)
        print(loss / x.shape[0])

The model trains perfectly well, and this is what I would really like to understand: what is the difference between my loop and this one?

I used TyXe and it works really well. However, for this specific project I would like to make Pyro work with PyTorch Lightning, and for that I need to break svi.step() down into its individual steps.

Like I said, your guide has parameters. SVI makes sure to trace and collect those; you are not doing so. Consequently, you learn nothing.