HyperParameter Optimization in Pyro

I am trying to optimize my Pyro program (MNIST classification) with respect to the following hyperparameters (and example grid values):

  • AutoGuide init_scale: [10, 1, 0.1, 0.01]
  • batch size: [10, 100, 500, 5000]
  • learning rate: [0.01, 0.001]

The Pyro model creation is enclosed in a parametrized function, Run(...), and I pass the various hyperparameters to this function in a loop (the 'usual' ML way), as sketched below. The output measures the accuracy of the argmax of the samples on a test set (which isn't very Bayesian, but helps me see which parameters help).
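For illustration, here is a minimal sketch of the outer grid-search loop I have in mind; Run(...) is my own function and its exact signature here is hypothetical:

from itertools import product

init_scales = [10, 1, 0.1, 0.01]
batch_sizes = [10, 100, 500, 5000]
learning_rates = [0.01, 0.001]

results = []
for init_scale, batch_size, lr in product(init_scales, batch_sizes, learning_rates):
    # Run(...) builds the model/guide, trains with SVI, and returns a dict of metrics
    result_dict = Run(init_scale=init_scale, batch_size=batch_size, lr=lr)
    results.append(result_dict)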
I noticed that there is some parameter leakage. For example, the sum of the init_scale parameter at the beginning of a run with a new set of hyperparameters is the same as the sum of the init_scale parameter at the end of the previous run, even after calling pyro.clear_param_store(). The relevant part of my code is as follows:

model = MakeModel()
guide = autoguide.AutoDiagonalNormal(model, init_scale=initial_variance)

_ = guide(torch.rand(1, *input_dim), 1)  # input_dim is the shape of the input images

# result_dict stores parameter values to be returned
result_dict['GuideLocBefore'] = pyro.param("AutoDiagonalNormal.loc").abs().sum().item()
result_dict['GuideScaleBefore'] = pyro.param("AutoDiagonalNormal.scale").sum().item()

####################################################
# train the model
pyro.clear_param_store()  # this should clean the slate, right?

_ = guide(torch.rand(1, *input_dim), 1)
result_dict['GuideLocBefore1'] = pyro.param("AutoDiagonalNormal.loc").abs().sum().item()
result_dict['GuideScaleBefore1'] = pyro.param("AutoDiagonalNormal.scale").sum().item()

What I see is that GuideScaleBefore1 = GuideScaleBefore = the guide scale from the end of the previous call to the Run() function.
Questions:
How do we stop this leakage?
Also, is this the way to do hyperparameter optimization?

Many thanks in advance!

Please see clear_param_store. Also please see ParamStoreDict.
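For context, a minimal sketch of how the global store can be inspected and cleared; the parameter names assume the AutoDiagonalNormal guide from the snippet above:

import pyro

param_store = pyro.get_param_store()  # the global ParamStoreDict

# list everything currently registered (e.g. "AutoDiagonalNormal.loc", "AutoDiagonalNormal.scale")
print(list(param_store.keys()))

# wipe the store between independent runs
pyro.clear_param_store()
print(list(param_store.keys()))  # should now be empty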

Thanks @martinjankowiak!
My understanding is that the parameter store is 'global', so the values in it persist across successive calls to the Run() function, and that clear_param_store() clears this store. This explains why result_dict['GuideScaleBefore'] contains scale values from the previous iteration.

But in the code snippet above, why does result_dict['GuideScaleBefore1'] still hold the scale and loc values from the previous run even after the clear_param_store() statement:

pyro.clear_param_store() # this should clean the slate, right?

Am I clearing the param store correctly?

When you do operations like .abs().sum().item(), you create an entirely new object (in this case a float) that is not tied in any way to the original parameter, so any such object will persist. If you want to keep the original parameters around (without affecting subsequent optimization runs), you should do something like

result_dict['GuideScaleBefore1'] = pyro.param("AutoDiagonalNormal.scale").clone()
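For example, one way to apply this (a sketch, assuming the same AutoDiagonalNormal guide as above) is to take a detached snapshot of the whole store before clearing it; the copies are unaffected by a later clear_param_store():

# snapshot every registered parameter as an independent tensor
snapshot = {name: value.detach().clone()
            for name, value in pyro.get_param_store().items()}

pyro.clear_param_store()  # the snapshot dict is unaffected by this
result_dict['GuideScaleBefore1'] = snapshot["AutoDiagonalNormal.scale"].sum().item()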

An update: I re-ran the HPO without the part before pyro.clear_param_store() and it worked, i.e. the loc and scale parameters from the previous run no longer appeared in the next one.
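For anyone reading later, a minimal sketch of the structure that worked for me; the body of Run(...) here is illustrative and the training loop is elided:

def Run(init_scale, batch_size, lr, input_dim):
    # start every run from a clean global parameter store
    pyro.clear_param_store()

    model = MakeModel()
    guide = autoguide.AutoDiagonalNormal(model, init_scale=init_scale)
    optimizer = pyro.optim.Adam({"lr": lr})
    svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())

    # ... training loop over mini-batches of size batch_size goes here ...

    result_dict = {}
    result_dict['GuideScaleAfter'] = pyro.param("AutoDiagonalNormal.scale").detach().clone()
    return result_dict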

Thanks for the help @martinjankowiak!