In PyTorch we can track additional losses, such as individual loss terms alongside the total. I don't see an option to do something similar in Pyro. How would you recommend doing this?
As an example, our group developed the scANVI tutorial (scANVI: Deep Generative Modeling for Single Cell Data with Pyro, Pyro Tutorials 1.8.4 documentation).
Is there any way to also record in `losses` the loss term from `pyro.sample("x", x_dist.to_event(1), obs=x)`, i.e. the reconstruction term? For example, in this training loop:
```python
import numpy as np

# svi, dataloader, scheduler and num_epochs are defined earlier in the tutorial
for epoch in range(num_epochs):
    losses = []
    # Take a gradient step for each mini-batch in the dataset
    for x, y in dataloader:
        if y is not None:
            y = y.type_as(x)
        loss = svi.step(x, y)
        losses.append(loss)
    # Tell the scheduler we've done one epoch.
    scheduler.step()
    print("[Epoch %02d] Loss: %.5f" % (epoch, np.mean(losses)))
```