Hi all,

I am running the tutorial on the low level SVI loop (Custom SVI Objectives — Pyro Tutorials 1.8.1 documentation) to better understand the `svi.step()` function, and also to integrate `Pyro` with `PyTorch Lightning` (which needs a `torch.optim` optimizer and not a `pyro.optim` object).
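
For context, the difference as I understand it: a `pyro.optim` optimizer takes a dict of `torch.optim` arguments and manages the parameters internally via the param store, whereas a `torch.optim` optimizer (what Lightning expects) must be handed the parameter tensors explicitly at construction time. Roughly:

```
import pyro.optim
import torch.optim

# Pyro-style: arguments as a dict, parameters handled by the param store
pyro_adam = pyro.optim.Adam({"lr": 0.001})

# torch-style (what Lightning wants): parameters passed in explicitly;
# `my_parameters` is a stand-in for whatever tensors need training
# torch_adam = torch.optim.Adam(my_parameters, lr=0.001)
```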

The loop does its computation but the loss does not decrease (it stays flat), which is most likely due to my misunderstanding of the `svi.step()` mechanics.

So far my loop looks like this:

```
import pyro
import torch
from pyro import poutine

loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
for epoch in range(10):
    for i, data in enumerate(train_loader, 0):
        x, y = data
        # Trace the parameter sites while computing the loss
        with poutine.trace(param_only=True) as param_capture:
            loss = loss_fn(mnistmodel, guide, x, y)
        loss.backward()
        # Get the unconstrained parameters from the trace
        params = set(
            site["value"].unconstrained()
            for site in param_capture.trace.nodes.values()
        )
        # Perform a gradient step and empty the gradients
        optimizer = torch.optim.Adam(params, lr=0.001)
        optimizer.step()
        optimizer.zero_grad()
        print(loss / x.shape[0])
```
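
One sanity check I can add after the parameters are collected is to confirm that gradients are actually reaching them, e.g. by printing their total gradient norm (a quick sketch using the `params` set from the loop above):

```
# inside the inner loop, after loss.backward() and the params collection
grad_norm = sum(p.grad.norm().item() for p in params if p.grad is not None)
print(f"total grad norm: {grad_norm:.4f}")
```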

My understanding is that the loop:

- gets a batch of training data from the DataLoader
- computes the loss while collecting the value of each parameter as it is encountered (with the `trace` feature)
- calculates the backward gradients over the learnable weights
- tells the optimizer to perform one learning step
- zeros the optimizer’s gradients
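
Re-reading the tutorial, I think the intended pattern may be to collect the parameters and build the optimizer once, before the loop, rather than re-creating the Adam state on every batch. My reading of it, as a sketch (assuming all parameters appear in the first trace):

```
# run one forward pass to register all parameters, then build the
# optimizer a single time so its state persists across steps
x, y = next(iter(train_loader))
with poutine.trace(param_only=True) as param_capture:
    loss_fn(mnistmodel, guide, x, y)
params = set(
    site["value"].unconstrained()
    for site in param_capture.trace.nodes.values()
)
optimizer = torch.optim.Adam(params, lr=0.001)

for epoch in range(10):
    for i, data in enumerate(train_loader, 0):
        x, y = data
        loss = loss_fn(mnistmodel, guide, x, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(loss / x.shape[0])
```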

Thank you very much in advance for your help!

If needed I can post more of the code, but when using the high level `svi.step()` the model trains normally:

```
from pyro.infer import SVI, Trace_ELBO

svi = SVI(mnistmodel, guide, adam, loss=Trace_ELBO())
for epoch in range(10):
    for i, data in enumerate(train_loader, 0):
        x, y = data
        loss = svi.step(x, y)
        print(loss / x.shape[0])
```
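
(For completeness: `adam` in the block above is the Pyro optimizer wrapper, constructed with the dict-of-args API, something like `adam = pyro.optim.Adam({"lr": 0.001})`.)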