Multiple guides/models in the same process

I would like to compare the results of using different guides for the same model (or the results from different models with different guides). As discussed in this post (Multiple guides for same model), it is of course possible to clear the global parameter store between runs of different models/guides, but that makes interactive comparison rather cumbersome. I’ve been using the following context manager to create multiple parameter stores: within its scope, a particular “local” parameter store is used rather than the “global” one.

Is this a sensible approach or is there something more straightforward that I’ve been missing?

import pyro
from pyro.params.param_store import ParamStoreDict


class paramstore_scope:
    """
    Context manager for using multiple parameter stores within the same process.
    """
    GLOBAL_STORE = None
    
    def __init__(self, paramstore=None):
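        # Compare against None: an empty ParamStoreDict is falsy (it defines
        # __len__), so `paramstore or ...` would silently discard it.
        self.paramstore = ParamStoreDict() if paramstore is None else paramstore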
        
    def __enter__(self):
        if self.GLOBAL_STORE is not None:
            raise RuntimeError('paramstore scopes cannot be nested')
        # Store the global paramstore.
        self.__class__.GLOBAL_STORE = pyro.poutine.runtime._PYRO_PARAM_STORE
        self._setup_paramstore(self.paramstore)
        return self
        
    def __exit__(self, *args):
        # Restore the global paramstore.
        self._setup_paramstore(self.GLOBAL_STORE)
        self.__class__.GLOBAL_STORE = None
        
    def _setup_paramstore(self, paramstore):
        # Replace all imports of the global param store with our version.
        pyro.poutine.runtime._PYRO_PARAM_STORE = \
            pyro.primitives._PYRO_PARAM_STORE = \
            pyro.nn.module._PYRO_PARAM_STORE = \
            paramstore
        # Replace static use of the imported param store.
        pyro.primitives._param = \
            pyro.poutine.runtime.effectful(paramstore.get_param, type="param")

An example looks like this.

with paramstore_scope() as scope1:
    # Do your inference for model 1 (which will leave the global paramstore unchanged).
    pass

with paramstore_scope() as scope2:
    # Do your inference for model 2 (which will leave the global paramstore unchanged).
    pass

# Interactively explore the different models.

hi @tillahoffmann,

we don’t currently have a good mechanism to do what you want but i believe your suggested mechanism is a good one. can you please open an issue to discuss further? it may be that your code snippet could already be the starting point for a useful PR.

thanks!

Sounds good. In case someone else stumbles across this thread, the issue is here.
