How to change DEBUG output while training CEVAE?

Hello, I am running the CEVAE tutorial (Example: Causal Effect VAE, from the Pyro Tutorials documentation) on Pyro 1.5.2, and I would like to change the way the losses are printed. I am using:

logging.getLogger("pyro").setLevel(logging.DEBUG)
logging.getLogger("pyro").handlers[0].setLevel(logging.DEBUG)

I am assuming that this controls how CEVAE outputs the losses when calling fit(). Is there a way to alter this, for instance to show the loss only every 10 epochs? I would also like to know how I can switch the logging on and off; currently I am restarting the kernel and not re-attaching the loggers…

Thank you in advance!

Hi @nichaz,

Currently the loss is logged every log_every SVI steps, as controlled by the CEVAE.fit() method. If you want to log every n epochs instead, I'd recommend subclassing CEVAE and overriding .fit() with a customized method (forking the existing code) that logs exactly when you want to log.
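I haven't run this, but a minimal sketch might look like the following. It assumes the class layout of pyro/contrib/cevae.py in Pyro 1.5.2, drops the learning-rate decay for brevity, and log_every_epochs is a made-up argument:

import logging

from torch.utils.data import DataLoader, TensorDataset

from pyro.contrib.cevae import CEVAE, PreWhitener, TraceCausalEffect_ELBO
from pyro.infer import SVI
from pyro.optim import ClippedAdam

logger = logging.getLogger(__name__)

class EpochLoggingCEVAE(CEVAE):
    # Fork of CEVAE.fit() that logs once every `log_every_epochs` epochs
    # instead of once every `log_every` SVI steps.
    def fit(self, x, t, y, num_epochs=100, batch_size=100,
            learning_rate=1e-3, weight_decay=1e-4, log_every_epochs=10):
        self.whiten = PreWhitener(x)  # mirrors the real fit()
        dataset = TensorDataset(x, t, y)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optim = ClippedAdam({"lr": learning_rate, "weight_decay": weight_decay})
        svi = SVI(self.model, self.guide, optim, TraceCausalEffect_ELBO())
        losses = []
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for x_batch, t_batch, y_batch in dataloader:
                x_batch = self.whiten(x_batch)
                epoch_loss += svi.step(x_batch, t_batch, y_batch,
                                       size=len(dataset)) / len(dataset)
            losses.append(epoch_loss)
            if epoch % log_every_epochs == 0:
                logger.info("epoch %5d loss = %0.6g", epoch, epoch_loss)
        return losses

Compare this against the fit() body in your installed version before relying on it.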

I haven’t done this myself, but this stackoverflow post suggests you can use one of

logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger().setLevel(logging.INFO)

to set the logging level dynamically at runtime.
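For switching it on and off without restarting the kernel: both the logger's level and each handler's level filter messages, so once your handler is at DEBUG you can mute and unmute everything by moving just the logger's level, e.g.

import logging

pyro_logger = logging.getLogger("pyro")

# Show the per-step loss messages during fit():
pyro_logger.setLevel(logging.DEBUG)
# ... cevae.fit(...) ...

# Mute them again, no kernel restart needed:
pyro_logger.setLevel(logging.WARNING)

Hope that helps!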


Hi @fritzo, thank you for your help, that makes sense. I am not exactly sure how to implement it, but I will try it out. I have a couple of additional issues: after fitting the model and running cevae.ite(X_test), I can't re-run ite(). If I try to run it on the train data, I get this error:
RuntimeError: nvrtc: error: failed to open nvrtc-builtins64_110.dll. Make sure that nvrtc-builtins64_110.dll is installed correctly. nvrtc compilation failed:

Also, I wanted to extract the encoded representations from the Guide model after training. I am calling
with torch.no_grad():
    test_encoded = cevae.guide.forward(X_test, t_test, y=y_test)

and I get a runtime error stating that Trace submodules cannot be called. Any clues about what may be wrong?

I can create another post if that would be more helpful.

Hi @nichaz, yes, another post would probably have been more helpful. Anyway, my guess is that both errors you are seeing (RuntimeError: nvrtc and “Trace submodules cannot be called”) are due to the PyTorch JIT somewhere. You might be able to work around them either by disabling JIT compilation (e.g. using Trace_ELBO instead of JitTrace_ELBO) or by updating PyTorch. Beyond that, I'd check the PyTorch forum and issue tracker.
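For the encoded representations, one untested workaround is to record the guide's sample sites with poutine.trace rather than calling .forward() directly. The site name "z" and the pre-whitening step below are assumptions from my reading of the CEVAE source, so check trace.nodes.keys() on your version; and if the guide really has been jit-traced, this may fail the same way:

import torch
import pyro.poutine as poutine

with torch.no_grad():
    # fit()/ite() whiten inputs internally, so the guide presumably
    # expects whitened features too (assumption).
    x_white = cevae.whiten(X_test)
    # Record every sample site from one execution of the guide.
    trace = poutine.trace(cevae.guide).get_trace(x_white, t_test, y_test)

# Assumed latent site name "z"; inspect trace.nodes.keys() to confirm.
z_encoded = trace.nodes["z"]["value"]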