How to access guide parameters?


I define my guide, with parameters x_loc and x_scale, as:

def guide(data):
    x_loc = pyro.param("x_loc", torch.rand(N*3,))
    x_scale = pyro.param("x_scale", 0.5*torch.ones(N*3,), constraint=constraints.positive)
    x = pyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))

I want to access the parameters x_loc and x_scale so I can pass them to a PyTorch optimizer:

optimizer = torch.optim.Adam(PARAMETERS, lr=0.001, betas=(0.90, 0.999))

I tried PARAMETERS = list(guide.parameters()), but it raises AttributeError: 'function' object has no attribute 'parameters'.

How can I access the parameters in the guide?


For guides that are functions like yours, you can trace and extract the unconstrained params:

from pyro import poutine

with poutine.trace(param_only=True) as tr:
    guide(data)  # run the guide once so its pyro.param sites are recorded
constrained_params = [site["value"] for site in tr.trace.nodes.values()]
PARAMS = [p.unconstrained() for p in constrained_params]

For guides that are modules (e.g. autoguides) you can simply use .parameters:

from pyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)
guide(data)  # call the guide once first: its parameters are created lazily
PARAMS = list(guide.parameters())

Another way is to grab all parameters from the param store (this includes every parameter currently in the store, not only this guide's):

constrained_params = list(pyro.get_param_store().values())
PARAMS = [p.unconstrained() for p in constrained_params]

Thank you @fritzo!