AutoGuide Parameters - Initializing and Recycling

I have an SVI model that seems quite prone to finding local minima. I’ve taken to re-running the model several times with different seeds to understand the range of possible errors and to choose between possible results.

I’m using AutoDelta to create my guide function. The problem is that the first SVI step (using JIT) can take a few minutes, so if I run a reasonable number of seeds, I can spend an hour or more just on the first SVI step of each model.

What I’m wondering is whether there is a way to snapshot the state of the model after the first SVI step and then re-randomize the parameters, avoiding the JIT recompiles. With the re-initialized params I’m hoping to run just my SVI again to get another ELBO trace, and throw that result in my bucket of possibilities.

Ideally I’d use the carefully chosen init_loc_fn s to regenerate starting tensors, just with a new rng seed. By stumbling around I can eventually run something that looks like:

guide.parts[1]._init_loc_fn(site) # guide is an AutoGuideList

but it feels like there must be a more straightforward way. My main question is:

  • How can I reinitialize existing parameters using the init_loc_fn?

A few other noob questions:

  • Is there a better way to get the sample sites than running a big trace/block combo every time, a la

     with poutine.trace() as sample_capture:
         # We use block here so that tracing records sample sites only.
         with poutine.block(expose_fn=lambda msg: msg["type"] == "sample"):
             guide(inputs)
    
  • Is manually adjusting the dict elements of pstore.get_state() a reasonable way to override parameters? (I saw calling pyro.param() before the guide makes them is best, but what if they are already in the store?)

  • I can’t quite get a clear picture of the .unconstrained() aspect when it comes to params / samples.

    • It seems that the .unconstrained() output tensor is the actual site that everything uses as key, is that correct?
    • If I want to override values in the tensor, does that mean that I have to set values in the unconstrained space? How do I map values in the constrained space into that space?
    • What’s the general practice for how and when to use .unconstrained()?

Thanks for all the work on such a great environment!

Maybe this post wasn’t focused enough, as it’s not exactly setting reply records :slight_smile:

@fritzo you seem very busy, but since you did the autoguide initialization PR, maybe you could take a glance?

The question is: Given an AutoGuide and a set of init_loc_fn s, I’d like to run an SVI, and then reseed the rng and run again with new initial parameters to get a different result without forcing JIT compile.

Is that in the realm of possibility?

If I understand correctly, the PyTorch jit is very slow to compile your model, so the first step is taking inordinately long. In this case I’d try to

  1. speed up jit compilation by passing jit_options={"optimize": False} to your elbo constructor.
  2. re-initialize by deleting the params from the param store then running a non-jit guide.
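import pyro
from pyro.infer import SVI, JitTrace_ELBO
from pyro.infer.autoguide import AutoDelta  # pyro.contrib.autoguide in older releases
from pyro.optim import Adam

data = ...
model = ...
steps = ...  # number of SVI steps per seed
guide = AutoDelta(model)
elbo = JitTrace_ELBO(jit_options={"optimize": False})
for seed in range(100):
    pyro.set_rng_seed(seed)  # so each run starts from a different, reproducible state
    pyro.get_param_store().clear()  # reset parameters
    guide(data)  # initialize model w/o recompiling
    svi = SVI(model, guide, Adam({}), elbo)  # reset Adam params each seed
    for step in range(steps):
        svi.step(data)
    # ...save params...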

Other questions:

Is manually adjusting the dict elements of pstore.get_state() a reasonable way to override parameters?

The recommended interface is to use dict-like methods of the ParamStore object returned by pyro.get_param_store(), e.g.

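import torch
import pyro
from torch.distributions import constraints

store = pyro.get_param_store()
del store["my_param1"]  # raises KeyError if the param does not exist yet
del store["my_param2"]
store["my_param1"] = torch.zeros(100)  # new parameter, unconstrained (real-valued)
store.setdefault("my_param2", torch.ones(100),
                 constraint=constraints.positive)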

Note that .unconstrained() is for internal use. You can avoid it completely by using the ParamStore interface and initializing with the appropriate constraint= kwarg. If you’re using an AutoDelta, though, simply calling guide(...) should initialize correctly.
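For the earlier question about mapping values between the constrained and unconstrained spaces, here is a minimal sketch using PyTorch's `torch.distributions.transform_to` registry. As far as I know, Pyro's param store relies on the same registry internally, but treat that as an assumption; the variable names are illustrative only.

```python
import torch
from torch.distributions import constraints, transform_to

# Transform that maps unconstrained reals into the constraint's support.
# For a positive constraint this amounts to exp (forward) and log (inverse).
to_constrained = transform_to(constraints.positive)

# A value expressed in the constrained (positive) space.
constrained_value = torch.tensor([0.5, 1.0, 2.0])

# Map it into the unconstrained space with the inverse transform.
unconstrained_value = to_constrained.inv(constrained_value)

# Mapping forward again should recover the original positive values.
round_trip = to_constrained(unconstrained_value)

print(unconstrained_value)
print(round_trip)
```

With the ParamStore interface you would normally pass the constrained value and let the store do this conversion, so a mapping like this is only needed if you really want to edit unconstrained tensors directly.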

Is there a better way to get the sample sites than running a big trace/block combo every time?

trace/block is the recommended method.

Thanks for the reply @fritzo.

There remain a couple of things hidden from view that are a bit vexing. It’s not clear to me:

  • exactly when a JIT compile is triggered, and
  • why the state of the optimizers seems to persist even if a new PyroOptim object is used

The latter point is especially confusing, and to date I’ve always had to discard my AutoGuide fn and reconstruct it if I want to restart an optimization (not just for this problem). Otherwise the optimization always resumes from the point the last PyroOptim left off. This also has implications for reproducibility.

Here’s some skeletal code that illustrates both points:

elbo = JitTraceEnum_ELBO(max_plate_nesting=1, jit_options={"optimize": False})
guide = AutoDelta(poutine.block(model, expose=["a", "b"]))

with Timer('init step'):
    svi = SVI(model, guide, Adam({}), elbo) # reset Adam params each seed
    svi.step(...all the args...)

for seed in range(3):
    pyro.set_rng_seed(seed)
    pyro.get_param_store().clear()  # reset parameters

    # re-declaring the guide will reset the optimization,
    # but will lose the JIT-compiled elbo
    # guide = AutoDelta(poutine.block(model, expose=["a", "b"]))

    guide(...all the args...)  # initialize model w/o recompiling
    svi = SVI(model, guide, Adam({}), elbo) # reset Adam params each seed
    with Timer('first step'):
        svi.step(...all the args...)
    with Timer('later steps'):
        for step in range(5):
            loss = svi.step(...all the args...)
        print(F"loss for seed {seed}: {loss}")

(Timer just prints wall clock time of block)
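Timer is not defined in the snippet above. A minimal sketch of such a helper, assuming it only needs to print the wall-clock time of the block (the `elapsed` attribute is an extra added here for convenience, not something from the original code):

```python
import time


class Timer:
    """Context manager that prints the wall-clock time spent in its block."""

    def __init__(self, label):
        self.label = label
        self.elapsed = None  # filled in when the block exits

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed = time.perf_counter() - self._start
        print(f"{self.label}: {self.elapsed:.3f} s")
        return False  # do not swallow exceptions raised inside the block


# Example usage: time a small piece of work.
with Timer("demo step"):
    total = sum(range(100_000))
```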

As noted, the redeclaration of the guide is the only way I can get N different optimization runs, otherwise it happily continues from the previous steps. But… redeclaring the guide causes the JIT to trigger on the first step, which is what we are trying to avoid here. Clearly you anticipated this by trying to reset the Adam parameters, but your method isn’t working for me.

Any general comments about when JIT is triggered and how PyroOptim state persists behind the scenes would also be appreciated!

Said another way, the code above (without resetting the guide each seed loop) is effectively one optimization. How that happens despite the pyro.get_param_store().clear() baffles me.

loss for seed 0: 694.05810546875
loss for seed 1: 693.53662109375
loss for seed 2: 693.03369140625
loss for seed 3: 692.5475463867188
...
loss for seed 9: 689.9354858398438
loss for seed 10: 689.5413208007812
loss for seed 11: 689.1569213867188

With the guide reset in the seed loop each run is independent:

loss for seed 0: 713.1170654296875
loss for seed 1: 752.1927490234375
loss for seed 2: 727.0341796875
loss for seed 3: 719.560302734375
loss for seed 4: 731.3297729492188
loss for seed 5: 724.9409790039062
...

Hmm I think the problem may be state leakage in AutoDelta. Can you try additionally clearing the guide.prototype_trace attribute:

for seed in range(3):
    pyro.set_rng_seed(seed)
    pyro.get_param_store().clear()
    guide.prototype_trace = None  # force guide to reinitialize parameters

    ...

I believe what’s happening is that (1) indeed the param store is being cleaned, but (2) by reusing the guide you switch initialization behavior from init_to_median() in the first seed to init_to_sample() in all subsequent seeds. Let me know your results and we can try to make this less error-prone.

Brilliant @fritzo - clearing the prototype_trace cracked the case. I’ll verify further tomorrow, but it does appear that I’m getting independent, reproducible runs without triggering JIT (and without recreating the guide every time). :tada:

I’m still a little confused about how the optimization just seemed to progress despite giving a fresh Adam object… Something related to your init_to_sample() theory?

And a quick question: if my guide is an AutoGuideList, do I still just guide.prototype_trace = None, or do I have to iterate that over all of the guide.parts?

Yes, sorry, in the case of an AutoGuideList you’ll need to iterate over the parts and set each part.prototype_trace = None. I’ll think about possible changes that could make this easier…

So it turns out that for an AutoGuideList you need to set both guide.prototype_trace = None and part.prototype_trace = None for each part in guide.parts.

(Maybe that’s what you meant, but not how I read it.)
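Putting the last few replies together, a small helper could clear the cached initialization on a guide and on any of its sub-guides. This is only a sketch, not part of Pyro: `reset_guide_init` is a hypothetical name, and it assumes sub-guides are exposed as `guide.parts` (as in the `guide.parts[1]` expression from the first post).

```python
def reset_guide_init(guide):
    """Clear cached prototype traces so the next guide(...) call re-initializes.

    Works for a single autoguide and, assuming sub-guides live in
    `guide.parts`, for an AutoGuideList as well.
    """
    # Reset the top-level guide.
    guide.prototype_trace = None
    # Reset each sub-guide, if the guide has any.
    for part in getattr(guide, "parts", []):
        part.prototype_trace = None
```

In the seed loop this would be called right after `pyro.get_param_store().clear()` and before `guide(...)`, in place of the single `guide.prototype_trace = None` line.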