Thanks fritzo for the help; I am now able to do the alternating optimization.
I have another question: is it OK to change the computational graph inside the model depending on its input, as follows? What I want to do is cache the kernel when the hyperparameters are blocked from optimization:
self.Ktril = None

def model(self, X, recomp_kernel=False):
    # Recompute the Cholesky factor only on the first call or when requested
    if recomp_kernel or self.Ktril is None:
        self.Ktril = kern_tril(X)
    if recomp_kernel:
        # step that optimizes the hyperparameters: keep gradients through the kernel
        pyro.sample("f", dist.MultivariateNormal(loc=torch.zeros(N), scale_tril=self.Ktril))
    else:
        # step that optimizes the rest: block gradients through the cached kernel
        pyro.sample("f", dist.MultivariateNormal(loc=torch.zeros(N), scale_tril=self.Ktril.detach()))
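As a sanity check on the detach idea, here is a minimal standalone PyTorch sketch (hypothetical names: `theta` stands in for a kernel hyperparameter, `w` for the remaining parameters, and `theta * eye` for `kern_tril(X)`) showing that `.detach()` blocks gradients to the hyperparameter while still letting them flow to the other parameters:

```python
import torch

theta = torch.tensor(2.0, requires_grad=True)  # stand-in for a kernel hyperparameter
w = torch.tensor(1.0, requires_grad=True)      # stand-in for "the rest" of the parameters

Ktril = theta * torch.eye(3)                   # stand-in for kern_tril(X)

# Step 1: kernel not detached -> both theta and w receive gradients
loss = (w * Ktril).sum()
loss.backward()
assert theta.grad is not None and w.grad is not None

# Step 2: kernel detached -> only w receives a gradient
theta.grad = None
w.grad = None
loss = (w * Ktril.detach()).sum()
loss.backward()
assert theta.grad is None and w.grad is not None
```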
For GPRN, I played with the MATLAB code; it also gets stuck at a trivial solution where sigma_noise is very large and the amplitude of the fit is very low, exactly the same situation as with SVI. I will reformulate the model/guide, perhaps marginalizing out f or w, so as to reduce the number of parameters and ease the initialization.