ParamStore not correctly updated

Hey guys,

I am training an SVI model, and after N epochs I introduce additional sample sites in the model’s forward method.
Although the training looks good, I observe that model.state_dict().keys() does not contain entries for the sample sites that were introduced later during training. The same holds for pyro.get_param_store().keys().

Do I have to add my new sample sites manually (I cannot find any functionality for this, though), or should this be considered a bug? If so, I’m happy to investigate it further and prepare a fix.

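Some pseudocode:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule
from pyro.optim import Adam

N = 100           # placeholder: number of warm-start epochs
NUM_EPOCHS = 200  # placeholder: total number of epochs


class MODEL_CLASS(PyroModule):
    def forward(self, warm_start_mode):
        if warm_start_mode:
            p = pyro.sample("p", dist.HalfCauchy(torch.ones(1)))
        ...  # rest of the model (elided); more sample sites appear once warm_start_mode=False


class SVIMODEL:
    def __init__(self):
        self.model = MODEL_CLASS()
        self.guide = AutoNormal(self.model)  # AutoNormal takes the model as its argument
        # Placeholder optimizer and loss settings.
        self.svi = SVI(self.model, self.guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

    def training(self, epoch):
        if epoch > N:
            self.svi.step(warm_start_mode=False)
        else:
            self.svi.step(warm_start_mode=True)


model = SVIMODEL()
for epoch in range(NUM_EPOCHS):
    model.training(epoch)
```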

Thanks in advance!

Hi @Madmore, although AutoNormal is not intended to support dynamic models, your workflow should be possible if you reset the guide just before your (N+1)th epoch. Here’s some pseudocode:

EDIT: This does not work; see below for an alternative solution.

```python
def training(self, epoch):
    if epoch > N:
        if epoch == N + 1:
            self.guide.prototype_trace = None  # reset guide
        self.svi.step(warm_start_mode=False)
    else:
        self.svi.step(warm_start_mode=True)
```

Let me know if that doesn’t work :smile:


Thanks for the reply!

I just played with it a bit, but unfortunately, setting self.guide.prototype_trace = None also resets the parameters of all the other sample sites, which I want to keep so that they don’t have to be retrained from scratch. I see two options:

  1. Manually add the new sample sites to the prototype_trace at step N+1, but I think this is a bit hacky and error-prone.
  2. Write a custom guide. Since I am in a model-exploration phase, though, I would prefer to keep an AutoGuide so I can analyze many models without modifying the guide every time.

Is there no other internal function I could use?

Hi @Madmore, thanks for trying that out :thinking: What if we instead simply recreate the guide (and, necessarily, the SVI object, since it stores the guide)? We can keep the same optimizer instance so that its moment estimates and learning-rate schedules remain intact:

```python
from pyro.infer import SVI
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam


class SVIMODEL:
    def __init__(self):
        self.model = MODEL_CLASS()
        self.guide = AutoNormal(self.model)
        self.optim = Adam(...)  # or whatever optimizer you prefer
        self.svi = SVI(self.model, self.guide, self.optim, ...)

    def training(self, epoch):
        if epoch > N:
            if epoch == N + 1:
                # Reset the guide since the model structure has changed.
                # Reusing self.optim keeps its moment estimates and LR schedule.
                self.guide = AutoNormal(self.model)
                self.svi = SVI(self.model, self.guide, self.optim, ...)
            self.svi.step(warm_start_mode=False)
        else:
            self.svi.step(warm_start_mode=True)
```

This version seems to work for me in a unit test I’ve added to Pyro.


Actually I think AutoNormalMessenger should be a drop-in replacement for AutoNormal while fully supporting dynamic models. If you get a chance to try this out, I’d love to hear your results.



Recreating the guide and SVI object works like a charm, thanks! Good idea.

I’ll try out the AutoNormalMessenger and give you an update soon!

I just went through your unit test, and I find some of the behaviour a bit unexpected. Two comments:

  1. In your unit test you overwrite the guide with an AutoNormal of another model, yet the loc estimate is preserved. Shouldn’t the loc/scale parameters be reset as well when I create a new guide instance?
  2. If I run the following code, I would expect the print statement to return 0 (the median of a Normal(0, 1) distribution), but it returns a random value close to zero. Why is it not exactly zero?
```python
import pyro
import torch
import pyro.distributions as dist
from pyro.infer.autoguide.initialization import init_to_median

def model1():
    x = pyro.sample("x", dist.Normal(0, 1))

guide = pyro.infer.autoguide.AutoNormal(model1, init_scale=1e-5, init_loc_fn=init_to_median)
guide()  # initialize

print(guide.locs.x)
```

P.S.: There is a small typo in the docstring of the init_to_median function: I guess it should say median instead of mean. I’m always a bit hesitant to open a PR for a single typo (maybe you have a branch for that) :wink:

  1. The guide-overriding workaround relies on Pyro’s global param store: the new AutoNormal instance looks up its parameters by name and picks up the values trained by the old one. It’s kind of gross, and the AutoNormalMessenger solution seems cleaner.

  2. init_to_median computes an empirical median over a fixed number of samples. This is intentional: a little noise is important so that, in hierarchical models, outer scale random variables don’t collapse due to zero-variance downstream variables.
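To see why the printed loc is close to, but not exactly, zero, here is a sketch of the idea with plain torch. It mimics, rather than reproduces, what init_to_median does: take the median of a handful of draws from the prior. The helper names empirical_median and spread and the sample counts are assumptions for illustration; in Pyro, the number of draws is controlled by init_to_median’s num_samples argument.

```python
import torch


def empirical_median(num_samples, seed):
    # Hypothetical helper: median of num_samples draws from Normal(0, 1),
    # i.e. the kind of estimate used as an initial loc.
    gen = torch.Generator().manual_seed(seed)
    samples = torch.randn(num_samples, generator=gen)
    return samples.median().item()


def spread(num_samples, num_trials=200):
    # Standard deviation of the empirical median across independent trials.
    medians = torch.tensor(
        [empirical_median(num_samples, seed) for seed in range(num_trials)]
    )
    return medians.std().item()


few = spread(15)     # a small number of draws: noticeably noisy medians
many = spread(1001)  # many draws: medians cluster tightly around 0
print(f"std of median with 15 samples:   {few:.3f}")
print(f"std of median with 1001 samples: {many:.3f}")
```

The spread shrinks roughly like 1/sqrt(num_samples), so a small sample count keeps exactly the bit of initialization noise described above.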

Don’t hesitate to open a PR even if it fixes only one word! PRs are the only way docs and code improve :smile:
