@tulerpetontidae and I need to freeze optimisation updates for a subset of genes. Based on the description,
pyro.mask should allow this, but it would be good to understand its behaviour better.
Is pyro.mask(torch.zeros()) equivalent to setting
requires_grad=False? If it is not the same mechanism, does it lead to the same effect, i.e. freezing the values of the parameters?
poutine.mask(mask=False) will zero out loss terms of sample statements in its context. In some circumstances this can be equivalent to setting
requires_grad=False on some parameters. For example:
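- if you freshly initialize an optimizer, so that it carries no momentum or other state from earlier training steps,
- and you have a parameter that affects only pyro.sample statements inside a poutine.mask(mask=False) context,
- then that parameter will not be updated during training.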
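A minimal PyTorch sketch of these conditions, under stated assumptions: a toy per-gene Gaussian loss stands in for the model, and zeroing masked terms with torch.where stands in for poutine.mask (the helpers masked_loss and train, and the data values, are hypothetical). Masked genes get an exactly zero gradient. A freshly created Adam then leaves them untouched, whereas an Adam that accumulated momentum during earlier unmasked steps keeps moving them.

```python
import torch


def masked_loss(loc, data, mask):
    # Per-gene Gaussian negative log-likelihood (up to a constant).
    nll = 0.5 * (data - loc) ** 2
    # Zero out masked terms: a rough stand-in for poutine.mask, not Pyro itself.
    return torch.where(mask, nll, torch.zeros_like(nll)).sum()


def train(loc, optim, data, mask, steps):
    for _ in range(steps):
        optim.zero_grad()
        masked_loss(loc, data, mask).backward()
        optim.step()


data = torch.tensor([1.0, -2.0, 3.0, 0.5])
loc = torch.zeros(4, requires_grad=True)
old_optim = torch.optim.Adam([loc], lr=0.1)

# Phase 1: every gene contributes, so Adam accumulates momentum for all of them.
train(loc, old_optim, data, torch.ones(4, dtype=torch.bool), steps=10)

# Phase 2: try to freeze genes 2 and 3 by masking their loss terms.
mask = torch.tensor([True, True, False, False])
before = loc.detach().clone()
loc_fresh = loc.detach().clone().requires_grad_(True)

# (a) Reusing the old optimizer: stale momentum still moves the masked genes.
train(loc, old_optim, data, mask, steps=10)

# (b) A freshly initialized optimizer: masked genes keep their exact values.
fresh_optim = torch.optim.Adam([loc_fresh], lr=0.1)
train(loc_fresh, fresh_optim, data, mask, steps=10)

print("last gradient in (b):", loc_fresh.grad)
print("masked genes, reused optimizer:", (loc.detach() - before)[2:])
print("masked genes, fresh optimizer: ", (loc_fresh.detach() - before)[2:])
```

pyro.optim.Adam wraps torch.optim.Adam, so the same momentum effect is expected under SVI, although this sketch does not exercise Pyro itself.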
But beware that there is not always a one-to-one correspondence between parameters and loss terms:
- some autoguides, like AutoMultivariateNormal, parameterize the joint posterior in such a way that many parameters affect all downstream samples, so it's tricky to freeze only a few parameters
- in amortized inference, many parameters (e.g. the weights of a shared encoder network) affect all downstream losses
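To illustrate why shared parameters defeat masking, here is a small PyTorch sketch of the amortized case, assuming a single hypothetical nn.Linear "encoder" that maps per-gene features to per-gene estimates. Only genes 0 and 1 contribute to the loss, yet the estimates for the masked genes 2 and 3 still change, because the encoder weights they depend on receive gradients from the unmasked genes.

```python
import torch

torch.manual_seed(0)
features = torch.randn(4, 3)  # hypothetical per-gene input features
data = torch.randn(4)
mask = torch.tensor([True, True, False, False])  # genes 2 and 3 masked out

# One small network shared by all genes, as in amortized inference.
encoder = torch.nn.Linear(3, 1)
optim = torch.optim.Adam(encoder.parameters(), lr=0.1)

with torch.no_grad():
    before = encoder(features).squeeze(-1).clone()

for _ in range(50):
    optim.zero_grad()
    loc = encoder(features).squeeze(-1)
    nll = 0.5 * (data - loc) ** 2
    # Only the unmasked genes contribute to the loss.
    torch.where(mask, nll, torch.zeros_like(nll)).sum().backward()
    optim.step()

with torch.no_grad():
    after = encoder(features).squeeze(-1)

# The masked genes' estimates moved anyway, because the weights are shared.
print("change for masked genes:", (after - before)[2:])
```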
It may be safer to simply
.detach_() the parameters you want to freeze. You can easily detach an entire (sub-)module:
for p in guide.submodule_i_want_to_detach.parameters():
    p.detach_()  # in place; p.requires_grad becomes False
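A self-contained PyTorch sketch of this loop, assuming a hypothetical plain nn.Module stands in for the guide (PyroModule subclasses torch.nn.Module, so .parameters() behaves the same way there). After detach_(), the frozen parameters no longer require gradients, receive no .grad, and stay unchanged, while the rest of the guide keeps training. It uses a plain PyTorch loop rather than SVI.

```python
import torch
from torch import nn


class Guide(nn.Module):
    # Hypothetical stand-in for a guide with one submodule to freeze.
    def __init__(self):
        super().__init__()
        self.submodule_i_want_to_detach = nn.Linear(3, 2)
        self.other = nn.Linear(2, 1)

    def forward(self, x):
        return self.other(self.submodule_i_want_to_detach(x))


torch.manual_seed(0)
guide = Guide()
for p in guide.submodule_i_want_to_detach.parameters():
    p.detach_()  # in place: p.requires_grad is now False

frozen_before = [p.clone() for p in guide.submodule_i_want_to_detach.parameters()]
other_before = [p.detach().clone() for p in guide.other.parameters()]

# Only hand parameters that still require gradients to the optimizer.
optim = torch.optim.Adam([p for p in guide.parameters() if p.requires_grad], lr=0.1)
x = torch.randn(5, 3)
for _ in range(10):
    optim.zero_grad()
    guide(x).pow(2).sum().backward()
    optim.step()

frozen_after = list(guide.submodule_i_want_to_detach.parameters())
# Frozen parameters: no longer require grad and are bit-for-bit unchanged.
print([p.requires_grad for p in frozen_after])
print(all(torch.equal(a, b) for a, b in zip(frozen_before, frozen_after)))
```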