Skip updates of model parameters in an alternating fashion


I’d like to estimate a model y=f(g(h(x))) and update the parameters in an alternating fashion, i.e.

  • in every even SVI step, I’d like to update the parameters of g and h, and
  • in every odd SVI step, I’d like to update the parameters of f and h.

Hence, h is updated in every step. One can assume that f, g, and h are all simple non-linear functions.
If I understand correctly, unlike in PyTorch, I cannot simply set .requires_grad to False for certain parameters in the guide.
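For comparison, the plain-PyTorch pattern referred to above might look like the following minimal sketch. Everything in it is hypothetical: f, g and h are stand-in one-parameter functions, the data are synthetic, and params, forward and train_step are made-up names. On each step the frozen parameter gets requires_grad=False; because zero_grad(set_to_none=True) leaves its .grad as None, Adam skips it entirely.

```python
import torch

# Hypothetical scalar parameters for h, g and f (one each, for illustration).
params = {name: torch.nn.Parameter(torch.tensor(0.5)) for name in ("h", "g", "f")}
opt = torch.optim.Adam(list(params.values()), lr=0.01)

x = torch.linspace(-2.0, 2.0, 50)
y = 1.5 * torch.tanh(2.0 * torch.tanh(0.8 * x))  # synthetic targets


def forward(x):
    # y_hat = f(g(h(x))) with hypothetical simple non-linear f, g, h.
    h = torch.tanh(params["h"] * x)
    g = torch.tanh(params["g"] * h)
    return params["f"] * g


def train_step(step):
    even = step % 2 == 0
    # Even steps train g and h; odd steps train f and h.
    params["f"].requires_grad_(not even)
    params["g"].requires_grad_(even)
    opt.zero_grad(set_to_none=True)  # frozen parameters keep .grad = None
    loss = ((forward(x) - y) ** 2).mean()
    loss.backward()
    opt.step()  # Adam skips parameters whose .grad is None
    return loss.item()


for step in range(100):
    train_step(step)
```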

What I tried so far:

  1. First, I tried to use two different guides, one for h and f and one for h and g. This does not work, because the two guides do not share h.
  2. Second, I used an if statement in my guide to skip the sample statements of g every other epoch, but the parameters of g are somehow still updated. Is that to be expected?

Are there any other possible solutions (or a fix for approach 2)? I am also happy to provide a minimal working example.

Thanks in advance!

I guess a simple way is to add a flag to your model/guide that controls whether some params are detached. Something like:

import pyro

def model(detach=False):
    p = pyro.param("p", ...)
    # detach=True cuts p out of the autograd graph, so no gradient reaches it
    p = p.detach() if detach else p

Then toggle the flag at each SVI step.