I would like to compare the results of using different guides for the same model (or the results from different models with different guides). As discussed in this post (Multiple guides for same model), it is of course possible to clear the global parameter store between runs of different models/guides, but that makes interactive comparison of the models/guides rather cumbersome. I’ve been using the following context manager to create multiple parameter stores (i.e., within the scope, a particular “local” parameter store is used rather than the “global” one).
Is this a sensible approach or is there something more straightforward that I’ve been missing?
import pyro
from pyro.params.param_store import ParamStoreDict
class paramstore_scope:
    """
    Context manager for using multiple parameter stores within the same process.

    While the scope is active, the module-level ``_PYRO_PARAM_STORE`` references
    in ``pyro.poutine.runtime``, ``pyro.primitives`` and ``pyro.nn.module`` (and
    the ``pyro.primitives._param`` primitive) point at ``self.paramstore``
    instead of the global store. On exit the previous global store is put back.

    Scopes cannot be nested: entering a scope while another one is active
    raises ``RuntimeError``.

    Args:
        paramstore: The store to activate inside the scope. When ``None``
            (the default), a fresh ``ParamStoreDict`` is created.

    Attributes:
        paramstore: The store that is active inside this scope. It remains
            available on the instance after the scope has exited.
    """

    # The global store that was swapped out by the currently active scope, or
    # None when no scope is active. Doubles as the "is a scope active?" flag.
    GLOBAL_STORE = None

    def __init__(self, paramstore=None):
        # Compare against None explicitly: a caller-supplied store may be
        # falsy while it is still empty (e.g. if it defines ``__len__``), and
        # ``paramstore or ParamStoreDict()`` would then silently discard it.
        if paramstore is None:
            paramstore = ParamStoreDict()
        self.paramstore = paramstore

    def __enter__(self):
        # Track the active scope on the base class itself (not on
        # ``self.__class__``) so the nesting guard also works across
        # subclasses.
        if paramstore_scope.GLOBAL_STORE is not None:
            raise RuntimeError('paramstore scopes cannot be nested')
        # Store the global paramstore.
        paramstore_scope.GLOBAL_STORE = pyro.poutine.runtime._PYRO_PARAM_STORE
        self._setup_paramstore(self.paramstore)
        return self

    def __exit__(self, *args):
        try:
            # Restore the global paramstore.
            self._setup_paramstore(paramstore_scope.GLOBAL_STORE)
        finally:
            # Always clear the guard, even if restoring failed; otherwise
            # every later scope would raise "cannot be nested".
            paramstore_scope.GLOBAL_STORE = None

    def _setup_paramstore(self, paramstore):
        """Make ``paramstore`` the store used by Pyro's module-level references."""
        # Replace all imports of the global param store with our version.
        pyro.poutine.runtime._PYRO_PARAM_STORE = paramstore
        pyro.primitives._PYRO_PARAM_STORE = paramstore
        pyro.nn.module._PYRO_PARAM_STORE = paramstore
        # Replace static use of the imported param store: ``_param`` was built
        # from a bound ``get_param`` of the old store, so it has to be rebuilt.
        pyro.primitives._param = pyro.poutine.runtime.effectful(
            paramstore.get_param, type="param"
        )
An example looks like this.
with paramstore_scope() as scope1:
    # Do your inference for model 1 (which will leave the global paramstore unchanged).
    ...

with paramstore_scope() as scope2:
    # Do your inference for model 2 (which will leave the global paramstore unchanged).
    ...

# Interactively explore the different models; each scope keeps its own store
# on its ``paramstore`` attribute (scope1.paramstore, scope2.paramstore).