I am training an SVI model, and after `N` epochs I introduce additional sample sites in the model's `forward` method.
Although the training looks good, I observe that `model.state_dict().keys()` does not contain the sample sites that were introduced later during training. The same holds for
Do I have to register my new sample sites manually (I cannot find any functionality for this, though), or should this be considered a bug? If it is a bug, I'm happy to investigate it further and prepare a fix.
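```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule
from pyro.optim import Adam


class MODEL_CLASS(PyroModule):
    def forward(self, warm_start_mode):
        if warm_start_mode:
            p = pyro.sample("p", dist.HalfCauchy(torch.ones(1)))
        ...  # remaining sample sites


class SVIMODEL:
    def __init__(self):
        self.model = MODEL_CLASS()
        self.guide = AutoNormal(self.model)  # AutoNormal takes the model as its argument
        # optimizer and learning rate are placeholders
        self.svi = SVI(self.model, self.guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())

    def training(self, num_epochs=1000, N=100):  # hypothetical default values
        for epoch in range(num_epochs):
            # warm start for the first N epochs, then switch modes
            self.svi.step(warm_start_mode=epoch <= N)


model = SVIMODEL()
model.training()
```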
Thanks in advance!