I am doing SVI using the lower-level `pyro.poutine`. I'd like to use different optimizers with `MixedMultiOptimizer`, but I would like to access the parameter string names in a cleaner way. I want to use `MixedMultiOptimizer` because I know how to get it working with `pyro.optim.ClippedAdam`.
I compute the ELBO like so:

```python
def compute_elbo_loss(model, guide, args, model_condition_data={}):
    # http://pyro.ai/examples/effect_handlers.html#Example:-Variational-inference-with-a-Monte-Carlo-ELBO
    # https://docs.pyro.ai/en/stable/poutine.html#module-pyro.poutine.handlers
    conditioned_model = poutine.condition(model, data=model_condition_data)
    guide_trace = poutine.trace(guide).get_trace(args)
    model_trace = poutine.trace(
        poutine.replay(conditioned_model, trace=guide_trace)
    ).get_trace(args)
    p = model_trace.log_prob_sum()
    q = guide_trace.log_prob_sum()
    elbo = p - q
    elbo_loss = -elbo
    return elbo_loss
```
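For intuition, the single-sample estimate this computes is just log p(x, z) − log q(z) at a sample z ~ q. Here is a minimal Pyro-free sketch with scalar Normals mirroring the toy model below (the numeric values are arbitrary placeholders, not from my actual problem):

```python
import math
import random

def normal_logpdf(x, loc, scale=1.0):
    # Log-density of a Normal(loc, scale) at x.
    return -0.5 * math.log(2 * math.pi * scale**2) - (x - loc)**2 / (2 * scale**2)

random.seed(0)
x3_obs, loc_prior, x1_loc, x2_loc = 0.5, 0.0, 10.0, 0.3

# Draw the latents from the guide (analogue of guide_trace).
x1 = random.gauss(x1_loc, 1.0)
x2 = random.gauss(x2_loc, 1.0)

# log p: model log-probs replayed at the guide's samples, including the observed x3.
log_p = (normal_logpdf(x1, loc_prior) + normal_logpdf(x2, x1)
         + normal_logpdf(x3_obs, x2))

# log q: guide log-probs at the same samples.
log_q = normal_logpdf(x1, x1_loc) + normal_logpdf(x2, x2_loc)

elbo_loss = -(log_p - log_q)
```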
And I train like so (`lr1`/`lr2` are pulled in as arguments here):

```python
def train(model, guide, data, lr1=1e-3, lr2=1e-3):
    adam = pyro.optim.ClippedAdam({'lr': lr1})
    sgd = pyro.optim.SGD({'lr': lr2})
    # WANT TO CHANGE THIS PART: drop 'x1_loc' by manual indexing, leaving
    # ['net$$$linear_layers.0.weight', 'net$$$linear_layers.0.bias',
    #  'net$$$linear_layers.2.weight', 'net$$$linear_layers.2.bias',
    #  'net$$$linear_layers.4.weight', 'net$$$linear_layers.4.bias']
    net_param_names = list(pyro.get_param_store().keys())[1:]
    optim = MixedMultiOptimizer([(['x1_loc'], adam), (net_param_names, sgd)])
    losses = []
    for batch in data:
        with poutine.trace() as param_capture:
            with poutine.block(hide_fn=lambda node: node["type"] != "param"):
                elbo_loss = compute_elbo_loss(model, guide, batch,
                                              model_condition_data={})
        params = {name: site['value'].unconstrained()
                  for name, site in param_capture.trace.nodes.items()
                  if site['type'] == 'param'}
        optim.step(elbo_loss, params)
        losses.append(elbo_loss.item())
    return losses
```
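My mental model of what `MixedMultiOptimizer` does with those name lists: it routes each named parameter to the optimizer whose list contains it. A Pyro-free sketch of that dispatch (this is only an illustration of the routing idea, not Pyro's actual implementation; `partition_params` and the tags are hypothetical names):

```python
def partition_params(params, parts):
    """Route a {name: value} dict to sub-dicts, one per (name_list, tag) part.

    Hypothetical helper: illustrates name-based dispatch only.
    """
    routed = {tag: {n: params[n] for n in names if n in params}
              for names, tag in parts}
    covered = {n for names, _ in parts for n in names}
    leftover = set(params) - covered
    if leftover:
        raise ValueError(f"params not covered by any optimizer: {leftover}")
    return routed

params = {'x1_loc': 1.0, 'net$$$w': 2.0, 'net$$$b': 3.0}
parts = [(['x1_loc'], 'adam'), (['net$$$w', 'net$$$b'], 'sgd')]
print(partition_params(params, parts))
# -> {'adam': {'x1_loc': 1.0}, 'sgd': {'net$$$w': 2.0, 'net$$$b': 3.0}}
```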
As you can see, `net_param_names` is built by getting all the param names (including `x1_loc`) with `pyro.get_param_store().keys()` and then indexing them manually. But in my actual problem I have several nets. I'd like to just pick them out with something like `pyro.get_param_names_from_module('net')` that gives me a list `['net$$$linear_layers.0.weight', 'net$$$linear_layers.0.bias', ...]`. Of course, the string label `net$$$` is prepended to the respective parameter labels, so I could check what the strings start with.
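In case it helps frame the question, here is the prefix-matching workaround I have in mind. `get_param_names_from_module` is a hypothetical helper, not an existing Pyro function, and the `"$$$"` divider is hard-coded here on the assumption that `pyro.module` keeps using that naming convention; the sample list stands in for `list(pyro.get_param_store().keys())`:

```python
MODULE_DIVIDER = "$$$"  # assumed divider pyro.module() puts between module and param names

def get_param_names_from_module(param_names, module_name):
    """Return the param-store keys registered under pyro.module(module_name, ...).

    Hypothetical helper: filters by the '<module>$$$<param>' prefix convention.
    """
    prefix = module_name + MODULE_DIVIDER
    return [name for name in param_names if name.startswith(prefix)]

# In real code, param_names would be list(pyro.get_param_store().keys()); a sample:
names = ['x1_loc',
         'net$$$linear_layers.0.weight', 'net$$$linear_layers.0.bias',
         'other_net$$$fc.weight']
print(get_param_names_from_module(names, 'net'))
# -> ['net$$$linear_layers.0.weight', 'net$$$linear_layers.0.bias']
```

This works, but it feels like I am re-implementing bookkeeping that the param store should already know how to do, which is why I am asking for a cleaner built-in way.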
For reference, my model and guide are below (it's just a toy example for this question!):

```python
def model(args):
    with pyro.plate('mini_batch', args['x3_obs'].shape[0]):
        loc_prior = args['loc_prior']
        x1 = pyro.sample('x1', dist.Normal(loc_prior, 1))
        x2 = pyro.sample('x2', dist.Normal(x1, 1))
        x3 = pyro.sample('x3', dist.Normal(x2, 1), obs=args['x3_obs'])
    return x1, x2, x3

def guide(args):
    x1_loc = pyro.param('x1_loc', tensor(10.))
    pyro.module("net", net)
    x2_loc = net(tensor([1.]))
    with pyro.plate('mini_batch', args['x3_obs'].shape[0]):
        x1 = pyro.sample('x1', dist.Normal(x1_loc, 1))
        x2 = pyro.sample('x2', dist.Normal(x2_loc, 1))
    return x1_loc, x2_loc
```