`Coregionalize` gives you a dense covariance matrix instead of a block-diagonal matrix (formula A12). As for the component computation (formula A16): since we only use it in matrix multiplications, it is enough to store the original components with shape C x 7 and write a function such as `matmul(components, weight)` or `matmul(components, components.t())` that does all the math under the hood.
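A minimal sketch of the "store only the components" idea, assuming the expanded matrix is only ever used through products with the C x 7 components. The function names below are hypothetical, not part of Pyro. The trick is associativity: `(components @ components.t()) @ v` equals `components @ (components.t() @ v)`, and the right-hand side never builds the C x C matrix.

```python
import torch

def components_matmul(components, weight):
    # components: C x 7, weight: 7 x K -> result: C x K
    return components @ weight

def components_gram_matmul(components, v):
    # (components @ components.t()) @ v without materializing the C x C Gram matrix:
    # components.t() @ v is only 7 x K, so the cost is O(7 * C * K) instead of O(C^2 * K)
    return components @ (components.t() @ v)

torch.manual_seed(0)
C = 50
components = torch.randn(C, 7, dtype=torch.float64)
v = torch.randn(C, 3, dtype=torch.float64)
out = components_gram_matmul(components, v)  # shape: C x 3
```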
I think your approach is much more efficient, but there is no need to replicate so much of the paper's math in your code. In addition, the way you compute the loss looks strange to me, and it does not seem to include all of the prior `log_prob` terms.
To run the GPs as a batch and save computation, you only need to make a couple of changes to my code:
```python
import torch
import pyro
import pyro.distributions as dist

def model(self):
    N = self.y.size(-1)  # self.y has shape 7 x N
    # if using my method, X has shape 7N x 10
    Kff = self.kernel(self.X)  # shape: 7N x 7N
    Kff = block_diagonal_matrix(Kff)  # helper (not defined here) keeping the 7 diagonal blocks; shape: 7 x N x N
    # if you want to use the original X with shape N x 3 together with an nn.ModuleList of 7 RBF kernels:
    Kff = torch.stack([self.kernels[i](self.X) for i in range(7)])  # faster!
    # either way, after this step you get a tensor with shape 7 x N x N
    # self.noise: a scalar, or 7 x 1 x 1 for one noise level per output;
    # noise_scale: a batch of diagonal matrices with shape 7 x N x N, e.g. torch.eye(N).expand(7, N, N)
    Kff = Kff + self.noise * self.noise_scale
    Lff = torch.linalg.cholesky(Kff)  # Cholesky supports batches!
    # if you use precision instead of noise, set noise = 1 / precision and
    # initialize precision to a very high value, as in the paper
    zero_loc = self.X.new_zeros(7, N)
    # to_event() moves the batch dim of size 7 into the event, so obs=self.y (7 x N) is one sample site
    return pyro.sample("weights", dist.MultivariateNormal(zero_loc, scale_tril=Lff).to_event(), obs=self.y)
```
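```python
from pyro.contrib.gp.util import conditional

def forward(self, Xnew):
    # loop over the 7 kernels and condition each output separately; the first one is shown
    weight1_loc, weight1_cov = conditional(Xnew, self.X, self.kernels[0], self.y[0], ...)
```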