SVI does nothing when guide is a class object

Hey all,
I noticed a confusing behavior when realizing my guide as a class object. Consider the following toy example:

import pyro
import torch

initial_value_of_a = torch.tensor([1.0])

class guide_class():
    def __init__(self, a=initial_value_of_a):
        self.a = pyro.param('a', a)

    def __call__(self):
        # If I uncomment the following line, everything works like a charm
        # pyro.param('a')
        return pyro.sample('z', pyro.distributions.Normal(self.a, 1.0))


guide = guide_class()

After executing this code, ‘a’ is listed in the parameter store. However, SVI does not update the parameter ‘a’. You can check, for instance, that the following lines do not throw an error:

model = lambda: pyro.sample('z', pyro.distributions.Normal(0.0, 1.0))
svi = pyro.infer.SVI(model, guide, optim=pyro.optim.SGD({'lr': 1e-3}), loss=pyro.infer.Trace_ELBO())

for i in range(100):
    svi.step()

assert guide.a == initial_value_of_a

Even more confusing, I noticed that ‘a’ is suddenly updated by the code above if I uncomment the line pyro.param('a') in the __call__ method of guide_class. This behavior really puzzles me, especially since I wouldn’t expect pyro.param('a') to have any effect, as ‘a’ was already introduced by the __init__ method.

Does anyone have an idea why the SVI wasn’t changing ‘a’ in the first place?

Hi @JoeMart, by calling pyro.param('a', a), you are registering a parameter in a global ParamStoreDict, so your self.a is registered in that global store. In Pyro, SVI only updates parameters that are accessed through a pyro.param(...) statement while the guide (or model) is running, so if you comment out pyro.param('a'), that parameter will not be updated. If you don’t want pyro.param('a') to affect self.a, simply detach and clone it in the __init__ method:

self.a = pyro.param('a', a).detach().clone()

I think I got it: the pyro.param(...) statement has to appear explicitly in the definition of the guide, even if the variable is already declared elsewhere as a parameter. So it actually had nothing to do with the class definition after all. Thanks a lot!

To shed a little more light on the black box: Pyro traces the guide function during execution, which means it executes the guide and records the params and samples it encounters. Since you declared your param outside the scope of your guide, when guide() is called the program has no knowledge that self.a came from a param statement.