_Partially_ clearing the global parameter store, autoguide name collisions, and name scoping

These are really two questions, but they are closely related:

The first question: I’m working through some exercises that involve building multiple models, training each one separately, and then comparing their results. However, because Pyro has a global parameter store, it appears that if I want to clear the parameters of one specific model, I have to clear the parameters of ALL models with `pyro.clear_param_store()`. Is there any way to select which parameters to clear and keep the rest?

Second question: I have started using autoguides to set up quick inference problems because they make things really easy: I can make a guide in a single line. Great! The problem is that when I look at `pyro.get_param_store()`, the parameters have names like `AutoMultivariateNormal.loc` and `AutoMultivariateNormal.scale_tril`, and these names are the same for every autoguide of that type that I create. This causes name collisions when I build multiple models that use the same type of autoguide, and I get tensor shape mismatch errors because the parameters can have different shapes in different models. Is there any way to avoid these collisions? I noticed that older versions of Pyro had a `prefix` kwarg, but it looks like it has been removed. What is the proper way to avoid autoguide parameter name collisions?

I guess I also have a third, bonus question: what is the best way to handle these name collisions in general? Because parameter names live at a global level, it is very difficult to reuse code properly: you have to ensure that all the names are distinct. For example, if I have two identical models that I want to train on different data sets, I need to write a brand-new model/guide for the second one with slightly different parameter names. Is there any way to enforce some kind of local scoping?

For the first question, you can use the dict interface of the param store to delete individual parameters:

```python
import pyro

store = pyro.get_param_store()
# Delete only the named parameters; everything else stays in the store.
del store["param_name_1"]
del store["param_name_2"]
```

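For the autoguide collisions, you could wrap each of your autoguides in a `PyroModule`, so that the guide's parameter names are prefixed with the wrapper's name, e.g.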

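```python
from pyro.infer import SVI
from pyro.infer.autoguide import AutoMultivariateNormal, AutoNormal
from pyro.nn import PyroModule
# Parameters of guide1.guide get names starting with "guide1.guide."
guide1 = PyroModule("guide1")
guide1.guide = AutoNormal(model)
svi = SVI(model, guide1.guide, ...)
# ... and those of guide2.guide with "guide2.guide."
guide2 = PyroModule("guide2")
guide2.guide = AutoMultivariateNormal(model)
```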

For the general scoping question, I’d recommend using `PyroModule`. Let us know if you have any issues or further questions. We don’t have many `PyroModule` examples yet, but you can find usage by grepping the Pyro codebase for its internal use.

Good luck!
