Is using pyro.mask equivalent to setting requires_grad=False?

Hi all

@tulerpetontidae and I need to freeze optimisation updates for a subset of genes. Based on the description, pyro.mask should allow this. It would be good to understand its behaviour better.

Is pyro.mask(torch.zeros()) equivalent to setting requires_grad=False? If it is not the same mechanism, does it have the same effect, i.e. freezing the values of the parameters?

poutine.mask(mask=False) will zero out loss terms of sample statements in its context. In some circumstances this can be equivalent to setting requires_grad=False on some parameters. For example:

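  • if you freshly initialize an optimizer (e.g. Adam), because a stateful optimizer can keep moving a parameter via momentum it accumulated earlier, even when the current gradient is zero
  • and you have a parameter that affects only pyro.sample statements inside a poutine.mask(mask=False) context
  • then that parameter will not be updated during training.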
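To illustrate why the fresh optimizer matters, here is a minimal plain-PyTorch sketch (it does not use Pyro at all). It assumes that replacing masked loss terms with zero via `torch.where` is a fair stand-in for what `poutine.mask(mask=False)` does to the log-probability terms of the sites in its context; `run`, `a` and `b` are hypothetical names. Parameter `b` only feeds the masked term, so it receives a zero gradient.

```python
import torch


def run(pretrain_steps: int, masked_steps: int = 5):
    """Return the value of b just before and just after training with b's term masked."""
    a = torch.nn.Parameter(torch.tensor(1.0))
    b = torch.nn.Parameter(torch.tensor(1.0))
    opt = torch.optim.Adam([a, b], lr=0.1)

    def step(mask):
        opt.zero_grad()
        # One loss term per parameter; each term depends on only one of them
        terms = torch.stack([(a - 3.0) ** 2, (b - 3.0) ** 2])
        # Masked-out terms are replaced by 0, so they send zero gradient to b
        loss = torch.where(mask, terms, torch.zeros_like(terms)).sum()
        loss.backward()
        opt.step()

    # Optional warm-up with nothing masked: Adam builds up momentum for b
    for _ in range(pretrain_steps):
        step(torch.tensor([True, True]))
    before = b.item()

    # Now mask out b's term (the analogue of wrapping its sites in poutine.mask(mask=False))
    for _ in range(masked_steps):
        step(torch.tensor([True, False]))
    return before, b.item()


fresh_before, fresh_after = run(pretrain_steps=0)
warm_before, warm_after = run(pretrain_steps=5)
print(fresh_before == fresh_after)  # True: a fresh Adam never moves b
print(warm_before == warm_after)    # False: stale momentum keeps moving b
```

With a freshly created Adam (and no weight decay), a zero gradient gives a zero update, so `b` stays put. If Adam has already accumulated momentum for `b`, as in the warm-up run, it keeps moving `b` even though its gradient is now zero.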

But beware that there is not always a one-to-one correspondence between parameters and loss terms:

  • some autoguides like AutoMultivariateNormal parameterize the joint posterior over all latent variables, so many parameters affect all downstream samples and it's tricky to freeze only a few of them
  • in amortized inference, many parameters (e.g. the weights of an encoder network shared across data points) affect all downstream losses
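Conversely, a minimal sketch of why masking does not freeze shared parameters, again in plain PyTorch with `torch.where` standing in for the mask. The single weight `w` plays the role of a hypothetical amortized encoder shared by all data points, so masking the loss terms of some data points still leaves a nonzero gradient coming from the others.

```python
import torch

# Hypothetical amortized setup: one shared weight w serves every data point
x = torch.arange(1.0, 7.0)  # six data points
y = 2.0 * x
w = torch.nn.Parameter(torch.tensor(0.5))

terms = (w * x - y) ** 2  # one loss term per data point, all depending on w
# Mask out the first three data points (the ones we would like to "freeze")
mask = torch.tensor([False, False, False, True, True, True])
loss = torch.where(mask, terms, torch.zeros_like(terms)).sum()
loss.backward()

print(w.grad)  # nonzero: the unmasked terms still depend on w
```

Because `w` still gets a gradient, any optimizer step will change it, and with it the predictions for the masked data points too.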

It may be safer to simply call .detach_() on the parameters you want to freeze. You can easily detach all parameters of an entire (sub-)nn.Module with:

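```python
# Freeze all parameters of one submodule of the guide
for p in guide.submodule_i_want_to_detach.parameters():
    p.detach_()  # in place: p no longer requires grad
```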
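A minimal PyTorch sketch of the detach approach, assuming a hypothetical `Guide` module whose `frozen` and `trainable` submodules stand in for `guide.submodule_i_want_to_detach` and the rest of the guide. It checks that after `detach_()` the frozen parameters no longer require grad and are left unchanged by Adam, while the others still train.

```python
import torch
from torch import nn


class Guide(nn.Module):
    """Hypothetical guide with one submodule to freeze and one to train."""

    def __init__(self):
        super().__init__()
        self.frozen = nn.Linear(3, 1)
        self.trainable = nn.Linear(3, 1)

    def forward(self, x):
        return self.frozen(x) + self.trainable(x)


torch.manual_seed(0)
guide = Guide()

# Freeze one submodule in place, as in the loop above
for p in guide.frozen.parameters():
    p.detach_()  # p stays a leaf tensor but now has requires_grad=False

# Snapshot values before training
frozen_before = [p.detach().clone() for p in guide.frozen.parameters()]
trainable_before = [p.detach().clone() for p in guide.trainable.parameters()]

# Passing all parameters is fine: frozen ones never get a .grad, so Adam skips them
opt = torch.optim.Adam(guide.parameters(), lr=0.1)
x = torch.randn(8, 3)
for _ in range(3):
    opt.zero_grad()
    loss = guide(x).pow(2).mean()
    loss.backward()
    opt.step()

frozen_no_grad = all(not p.requires_grad for p in guide.frozen.parameters())
frozen_unchanged = all(
    torch.equal(b, p) for b, p in zip(frozen_before, guide.frozen.parameters())
)
trainable_changed = any(
    not torch.equal(b, p) for b, p in zip(trainable_before, guide.trainable.parameters())
)
print(frozen_no_grad, frozen_unchanged, trainable_changed)  # True True True
```

On a leaf parameter like these, `p.requires_grad_(False)` should have the same effect as `p.detach_()`, which is the requires_grad=False route asked about in the question.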