Deterministic warm-up training for VAEs

I’ve been doing some research with a VAE in Pyro, and it tends to be unstable at the beginning of training. I can lower the learning rate to avoid this, but then of course training is slower.

One approach that tries to help with this is a “deterministic warm-up” period at the start of training, during which the KL-divergence penalty is down-weighted (to zero, or to a fraction of its full weight). See this paper. Their intuition for the improvement is that the prior regularizes many parameters to zero early on, so they get stuck in a saddle point, whereas the warm-up period lets all the parameters reach a good starting point before the prior starts adding regularization.
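Put as an objective, warm-up means minimizing -E_q[log p(x|z)] + β·KL(q(z|x) ‖ p(z)), where the weight β grows from 0 (or a small value) to 1 over the warm-up period and then stays at 1. Below is a minimal sketch of such a schedule in plain Python. It assumes a linear ramp, which is a common choice but not necessarily the paper's exact schedule, and `kl_warmup_factor` is a hypothetical helper name.

```python
def kl_warmup_factor(step, warmup_steps, start=0.0, end=1.0):
    """Weight on the KL term at a given training step (hypothetical helper).

    Ramps linearly from `start` at step 0 to `end` at `warmup_steps`,
    then stays at `end` for the rest of training.
    """
    if warmup_steps <= 0:
        return end  # no warm-up: full KL weight from the start
    progress = min(step / warmup_steps, 1.0)
    return start + progress * (end - start)


# Example: a 1000-step warm-up that starts at 10% of the full KL weight.
schedule = [kl_warmup_factor(s, 1000, start=0.1) for s in (0, 500, 1000, 5000)]
```

Starting from a small positive `start` rather than 0 corresponds to the "just fractionally" variant.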

Because Pyro does some magic with my model, I’m not sure how to implement this. One approach would be to build a similar model that I train with a normal PyTorch optimizer and then initialize my VAE with those parameters, but that feels cumbersome. Is there a better way to adjust the training loss of a Pyro SVI object so that I can have a warm-up period? Perhaps by passing a custom loss to SVI for the first part of training? This might be a generally useful enhancement; if so, and if I can figure out how to implement it, I’ll write a PR.

Something like this is done in the DMM tutorial (look for KL annealing). It is also touched on in the custom loss tutorial.

Thanks, I’ll check those out.