Hey all,
I noticed some confusing behavior when implementing my guide as a class. Consider the following toy example:
    import pyro
    import torch

    initial_value_of_a = torch.tensor([1.0])

    class guide_class():
        def __init__(self, a=initial_value_of_a):
            self.a = pyro.param('a', a)

        def __call__(self):
            # If I uncomment the following line everything works like a charm
            # pyro.param('a')
            return pyro.sample('z', pyro.distributions.Normal(self.a, 1.0))

    guide = guide_class()
After executing this code, ‘a’ is listed in the parameter store. However, SVI does not update the parameter ‘a’. For instance, the following lines run without raising an error:
    model = lambda: pyro.sample('z', pyro.distributions.Normal(0.0, 1.0))
    svi = pyro.infer.SVI(model, guide, optim=pyro.optim.SGD({'lr': 1e-3}), loss=pyro.infer.Trace_ELBO())
    for i in range(100):
        svi.step()
    assert guide.a == initial_value_of_a
Even more confusingly, ‘a’ is suddenly updated by the code above if I uncomment the line pyro.param('a') in the __call__ method of guide_class. This really puzzles me, since I wouldn’t expect pyro.param('a') to have any effect: ‘a’ was already registered in the __init__ method.
Does anyone have an idea why the SVI wasn’t changing ‘a’ in the first place?