Hey guys,
I am training an SVI model, and after N epochs I introduce additional sample sites in the model’s `forward` method.
Although training looks fine, I observe that `model.state_dict().keys()` does not contain the sample sites that were introduced later during training. The same holds for `pyro.get_param_store().keys()`.
Do I have to update my new sample sites manually (I cannot find any functionality for this, though), or should this be considered a bug? If so, I’m happy to investigate it further and prepare a fix.
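Some pseudocode:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule
from pyro.optim import Adam

class SVIMODEL:
    def __init__(self):
        self.model = MODEL_CLASS()
        # AutoNormal has to be given the model it approximates
        self.guide = AutoNormal(self.model)
        self.svi = SVI(self.model, self.guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())

    def training(self, num_epochs, N):
        for epoch in range(num_epochs):
            # warm-start mode for epochs 0..N, full model afterwards
            self.svi.step(warm_start_mode=epoch <= N)

class MODEL_CLASS(PyroModule):
    def forward(self, warm_start_mode):
        # the additional site "p" is only sampled once warm-up is over
        if not warm_start_mode:
            p = pyro.sample("p", dist.HalfCauchy(torch.ones(1)))
        ...

model = SVIMODEL()
model.training(num_epochs=100, N=10)  # hypothetical values
```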
Thanks in advance!