In Pyro, is there a way to optimize groups of parameters alternately?

For example, I have two groups of parameters:

- g1, containing the parameters of the guide,
- g2, containing the hyperparameters of the model.

What I want to do is the following:

1. optimize the ELBO w.r.t. g1 for some iterations,
2. optimize the ELBO w.r.t. g2 for some iterations,
3. go back to step 1.

To be more concrete, I am trying to implement Gaussian Process Regression Networks (GPRN) with Pyro. However, if I optimize the random variables (f and w in the paper) and the hyperparameters (of the Gaussian process kernels) at the same time (with Adam), it gets stuck in a poor local minimum. That is just my guess: the fit is OK at the beginning, but the amplitude of the fit decreases as the ELBO improves. I checked the MATLAB implementation by Nguyen, and it alternates between the random variables and the hyperparameters. That is why I am asking this question.

Any suggestions on implementing GPRN are also appreciated :)

Hi @ruijiang, first note that alternating optimization may not be the best way to improve convergence. You might consider reformulating the model or guide, changing your parametric families, or warming up by optimizing only one set of parameters (say the guide params) before switching to joint optimization.

To do either alternating or warmup optimization, you can use poutine.block, e.g.

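```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(50) + 3.0  # toy observations, for illustration only

def model():
    p = pyro.param("model_param", torch.tensor(0.0))  # a model (hyper)parameter
    z = pyro.sample("z", dist.Normal(p, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(z, 1.0), obs=data)

def guide():
    q = pyro.param("guide_param", torch.tensor(0.0))  # a variational parameter
    pyro.sample("z", dist.Normal(q, 1.0))

# optimize only model params: hide the guide's param (it is called in the guide)
optim = Adam({'lr': 0.001})
elbo = Trace_ELBO()
svi = SVI(model, poutine.block(guide, hide=["guide_param"]), optim, elbo)
for _ in range(100):
    svi.step()
# optimize only guide params: hide the model's param (it is called in the model)
optim = Adam({'lr': 0.001})
elbo = Trace_ELBO()
svi = SVI(poutine.block(model, hide=["model_param"]), guide, optim, elbo)
for _ in range(100):
    svi.step()
# optimize both params
optim = Adam({'lr': 0.001})
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, elbo)
for _ in range(1000):
    svi.step()
```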

Thanks @fritzo for the help. I am now able to do alternating optimization.

I have another question: is it OK for the model to change its computational graph depending on its inputs, as below? What I want is to cache the kernel while the hyperparameters are blocked from optimization:

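```python
import torch
import pyro
import pyro.distributions as dist

class GPRN:
    def __init__(self):
        self.Ktril = None  # cached Cholesky factor of the kernel matrix

    def model(self, X, recomp_kernel=False):
        N = X.shape[0]
        if recomp_kernel or self.Ktril is None:
            self.Ktril = kern_tril(X)  # my kernel function (not shown)
            pyro.sample("f", dist.MultivariateNormal(loc=torch.zeros(N), scale_tril=self.Ktril))
        else:
            # detach: the graph behind the cached factor was freed by the previous backward()
            pyro.sample("f", dist.MultivariateNormal(loc=torch.zeros(N), scale_tril=self.Ktril.detach()))

# optimize the hyperparameters
svi_hyper.step(X, recomp_kernel=True)
# optimize the rest
svi_rest.step(X, recomp_kernel=False)
```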

As for GPRN, I played with the MATLAB code, and it also gets stuck at a trivial solution where sigma_noise is very large and the fit has very low amplitude, exactly the same situation as with SVI. I will reformulate the model/guide, perhaps by marginalizing out f or w, to reduce the number of parameters and ease initialization.