@tulerpetontidae and I need to freeze optimisation updates for a subset of genes. Based on the description, `poutine.mask` should allow this. It would be good to understand its behaviour better.
Is `poutine.mask(torch.zeros())` equivalent to setting `requires_grad=False`? If it is not the same mechanism, does it lead to the same effect, i.e. freezing the values of parameters?
`poutine.mask(mask=False)` will zero out the loss terms of the sample statements in its context. In some circumstances this can be equivalent to setting `requires_grad=False` on some parameters. For example:
- if you freshly initialize an optimizer (so there is no leftover momentum state from earlier steps),
- and you have a parameter that affects only `pyro.sample` statements inside a `poutine.mask(mask=False)` context,
- then that parameter will not be updated during training.
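The "fresh optimizer + fully masked parameter" case can be sketched with a plain-PyTorch analogue (a toy stand-in for illustration, not Pyro's actual masking machinery): multiplying a loss term by zero mimics the masked-out term, the parameter's gradient is then exactly zero, and a freshly initialized `Adam` leaves it untouched.

```python
import torch

# Toy analogue of poutine.mask(mask=False): multiply the loss term by zero.
p = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.Adam([p], lr=0.1)  # fresh optimizer: no accumulated momentum

loss = 0.0 * (p - 3.0) ** 2  # p appears only in this masked (zeroed) term
opt.zero_grad()
loss.backward()
opt.step()

print(p.item())  # still 1.0: zero gradient, so a fresh Adam makes no update
```

With stale momentum from earlier unmasked steps, Adam would keep moving the parameter even with a zero gradient, which is why the "freshly initialized" condition matters.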
But beware there’s not always a 1-to-1 correspondence between parameters and loss terms:
- some autoguides like `AutoMultivariateNormal` parameterize the joint posterior in a way that many parameters affect all downstream samples, so it's tricky to freeze only a few parameters;
- in amortized inference, many parameters affect all downstream losses
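The caveat shows up in the same toy setting (again plain PyTorch, not Pyro internals): if a parameter also feeds an unmasked loss term, masking one term does not freeze it.

```python
import torch

q = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.Adam([q], lr=0.1)

masked = 0.0 * (q - 3.0) ** 2  # this term is "frozen"
shared = (q - 5.0) ** 2        # but q also appears in an unmasked term
(masked + shared).backward()
opt.step()

print(q.item())  # q moved: masking one loss term did not freeze it
```

This is exactly the shared-parameter situation of `AutoMultivariateNormal` or an amortized encoder: gradients still flow in through the terms you did not mask.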
You may be safer simply `.detach_()`ing the parameters you want to freeze. You can easily detach an entire (sub-)module:

```python
for p in guide.submodule_i_want_to_detach.parameters():
    p.detach_()
```
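A quick way to check the in-place detach is to inspect `requires_grad` afterwards. A minimal sketch with a stand-in `torch.nn` module (not the actual guide from this thread):

```python
import torch

# Stand-in for guide.submodule_i_want_to_detach
submodule = torch.nn.Linear(2, 1)

for p in submodule.parameters():
    p.detach_()  # in-place: the parameter no longer requires grad

print(all(not p.requires_grad for p in submodule.parameters()))  # True
```

After this, backward passes simply skip those parameters, regardless of how many loss terms they touch, which sidesteps the parameter-to-loss-term bookkeeping above.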