Optimization - accessing pyro.module param string labels for MixedMultiOptimizer

I am doing SVI using the lower-level pyro.poutine effect handlers.

I’d like to use different optimizers with MixedMultiOptimizer, but I would like to access the parameter string names in a cleaner way.

I want to use MixedMultiOptimizer because I know how to get it working with pyro.optim.ClippedAdam.

I compute the ELBO like so:

from pyro import poutine

def compute_elbo_loss(model, guide, args, model_condition_data=None):
  # http://pyro.ai/examples/effect_handlers.html#Example:-Variational-inference-with-a-Monte-Carlo-ELBO
  # https://docs.pyro.ai/en/stable/poutine.html#module-pyro.poutine.handlers
  conditioned_model = poutine.condition(model, data=model_condition_data or {})
  guide_trace = poutine.trace(guide).get_trace(args)
  # replay the model against the guide's samples and record the resulting trace
  model_trace = poutine.trace(
          poutine.replay(conditioned_model, trace=guide_trace)).get_trace(args)
  p = model_trace.log_prob_sum()  # log p(x, z) at the guide's sampled z
  q = guide_trace.log_prob_sum()  # log q(z)
  elbo = p - q
  elbo_loss = -elbo
  return elbo_loss
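For reference, the quantity compute_elbo_loss returns is a single-sample Monte Carlo estimate of the negative ELBO: the guide draws its latents once, and the model is scored at those values with x3 observed. Writing the model's joint density as p and the guide's as q:

$$
\mathrm{ELBO} = \mathbb{E}_{q(x_1, x_2)}\big[\log p(x_1, x_2, x_3) - \log q(x_1, x_2)\big]
\approx \log p(\hat{x}_1, \hat{x}_2, x_3) - \log q(\hat{x}_1, \hat{x}_2),
\qquad (\hat{x}_1, \hat{x}_2) \sim q
$$

Here p = model_trace.log_prob_sum() and q = guide_trace.log_prob_sum() (both summed over the mini_batch plate), so elbo_loss = q - p. Averaging this over several guide samples would give a lower-variance estimate.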

And I train like so:

import pyro
from pyro import poutine
from pyro.optim.multi import MixedMultiOptimizer

def train(model, guide, data):
  # lr1 and lr2 are learning rates defined elsewhere
  adam = pyro.optim.ClippedAdam({'lr': lr1})
  sgd = pyro.optim.SGD({'lr': lr2})
  # WANT TO CHANGE THIS PART: relies on 'x1_loc' being the first key in the param store
  # (and on the guide having run once already); gives ['net$$$linear_layers.0.weight',
  # 'net$$$linear_layers.0.bias', 'net$$$linear_layers.2.weight', 'net$$$linear_layers.2.bias',
  # 'net$$$linear_layers.4.weight', 'net$$$linear_layers.4.bias']
  net_param_names = list(pyro.get_param_store().keys())[1:]
  optim = MixedMultiOptimizer([(['x1_loc'], adam), (net_param_names, sgd)])
  losses = []
  for batch in data:
      # record only the param sites touched while computing the loss
      with poutine.trace() as param_capture:
          with poutine.block(hide_fn=lambda node: node["type"] != "param"):
              elbo_loss = compute_elbo_loss(model, guide, batch, model_condition_data={})
      params = {name: site['value'].unconstrained()
                for name, site in param_capture.trace.nodes.items()
                if site['type'] == 'param'}
      optim.step(elbo_loss, params)
      losses.append(elbo_loss.item())
  return losses

As you can see, net_param_names is built by listing all the param names with pyro.get_param_store().keys(), which includes x1_loc, and then slicing off the first entry by hand. But in my actual problem I have several nets. I’d like to just pick them out with something like pyro.get_param_names_from_module('net') that gives me a list ['net$$$linear_layers.0.weight','net$$$linear_layers.0.bias',...]. Of course, since the prefix net$$$ is prepended to each parameter name, I could also just check what each string starts with.
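The "check what each string starts with" idea can be sketched as below. This is a minimal sketch, not a Pyro API: param_names_for_module and group_param_names are hypothetical helpers, the '$$$' divider is taken from the parameter names shown above, and net2 is a made-up second module. Matching on module_name + '$$$' rather than module_name alone keeps 'net' from also matching 'net2'.

```python
from typing import Dict, Iterable, List, Optional

# Divider seen in names like 'net$$$linear_layers.0.weight' (taken from the question).
MODULE_DIVIDER = "$$$"


def param_names_for_module(names: Iterable[str], module_name: str,
                           divider: str = MODULE_DIVIDER) -> List[str]:
    """Hypothetical helper: names registered via pyro.module(module_name, ...)."""
    prefix = module_name + divider  # include the divider so 'net' != 'net2'
    return [name for name in names if name.startswith(prefix)]


def group_param_names(names: Iterable[str],
                      divider: str = MODULE_DIVIDER) -> Dict[Optional[str], List[str]]:
    """Hypothetical helper: group names by module prefix.

    Names from plain pyro.param calls (no divider) go under the key None.
    """
    groups: Dict[Optional[str], List[str]] = {}
    for name in names:
        module = name.split(divider, 1)[0] if divider in name else None
        groups.setdefault(module, []).append(name)
    return groups


# Names shaped like the ones in the question; 'net2' is a made-up second module.
names = ['x1_loc',
         'net$$$linear_layers.0.weight', 'net$$$linear_layers.0.bias',
         'net2$$$linear_layers.0.weight']
net_names = param_names_for_module(names, 'net')
all_groups = group_param_names(names)
```

In the training loop this would be called as param_names_for_module(pyro.get_param_store().keys(), 'net'), which still assumes the guide has run at least once so the param store is populated.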

For reference, my model and guide are below (it’s just a toy example for this question!):

from torch import tensor
import pyro
import pyro.distributions as dist

def model(args):
  with pyro.plate('mini_batch', args['x3_obs'].shape[0]):
    loc_prior = args['loc_prior']
    x1 = pyro.sample('x1', dist.Normal(loc_prior, 1))
    x2 = pyro.sample('x2', dist.Normal(x1, 1))
    x3 = pyro.sample('x3', dist.Normal(x2, 1), obs=args['x3_obs'])
  return x1, x2, x3

def guide(args):
  x1_loc = pyro.param('x1_loc', tensor(10.))
  # `net` is a torch.nn.Module defined elsewhere (its layers live in net.linear_layers)
  pyro.module("net", net)
  x2_loc = net(tensor([1.]))
  with pyro.plate('mini_batch', args['x3_obs'].shape[0]):
    x1 = pyro.sample('x1', dist.Normal(x1_loc, 1))
    x2 = pyro.sample('x2', dist.Normal(x2_loc, 1))
  return x1_loc, x2_loc

I’m not sure I understand your question, but can you do a trace like here?