Hello
I built a multi-output Gaussian process model using the built-in Coregionalization kernel, and I’m using a heteroscedastic approach: I sample the variance at each time step for each of the output dimensions. The multi-output part works fine when modelling several time series, but I noticed something odd when sampling the variance. With the code below, I can estimate the variance of the first time series (g0) well, but not the variance of the second time series (g1). I ran some tests where I swapped the order of g0 and g1 in the code, and the issue reversed: I could then estimate g1, but not g0. I’m doing inference with SVI, and further experiments suggest that increasing “num_particles” in the ELBO estimator mitigates the issue somewhat, without removing it.
    def model(self, X, y=None):
        self.set_mode("model")
        N = X.size(0)
        N1 = sum(X[:,1]==1)
        N2 = sum(X[:,2]==1)
        Kff = self.f_kernel(X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter # add jitter to the diagonal
        Lff = Kff.cholesky()
        Kgg0 = self.g0_kernel(X[:(N1),:]).contiguous()
        Kgg0.view(-1)[::(N1) + 1] += self.jitter # add jitter to the diagonal
        Lgg0 = Kgg0.cholesky()
        Kgg1 = self.g1_kernel(X[(N1):,:]).contiguous()
        Kgg1.view(-1)[::(N2) + 1] += self.jitter # add jitter to the diagonal
        Lgg1 = Kgg1.cholesky()
        if len(X.shape)>1:
            zero_loc_f = torch.zeros_like(X[:,0])
            zero_loc_g0 = torch.zeros_like(X[:N1,0])
            zero_loc_g1 = torch.zeros_like(X[:N2,0])
        else:
            zero_loc = torch.zeros_like(X)
        f = pyro.sample("f", dist.MultivariateNormal(zero_loc_f, scale_tril=Lff))
        g0 = pyro.sample("g0", dist.MultivariateNormal(zero_loc_g0, scale_tril=Lgg0))
        g1 = pyro.sample("g1", dist.MultivariateNormal(zero_loc_g1, scale_tril=Lgg1))
        g = torch.cat((g0, g1), 0)
        if y is None:
            return f
        else:
            return self.likelihood(f, g, y)
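For readers unfamiliar with the `view(-1)[::N + 1]` idiom in the model: a row-major N x N matrix, once flattened, has its diagonal entries at indices 0, N + 1, 2(N + 1), and so on, so a strided slice touches only the diagonal. The sketch below shows the same indexing with plain Python lists rather than tensors; `add_jitter_to_diagonal` is a hypothetical helper written for illustration, not part of the model. Note that the stride has to match the actual matrix size, so the `N2 + 1` stride used for `Kgg1` only hits the diagonal if `X[N1:]` has exactly `N2` rows.

```python
def add_jitter_to_diagonal(flat, n, jitter):
    """Add `jitter` to the diagonal of an n x n matrix stored row-major in `flat`.

    Hypothetical helper: mirrors `K.view(-1)[::n + 1] += jitter` on a plain list.
    """
    for idx in range(0, n * n, n + 1):  # 0, n + 1, 2 * (n + 1), ...
        flat[idx] += jitter
    return flat


n = 3
flat = [0.0] * (n * n)          # flattened 3 x 3 matrix of zeros
add_jitter_to_diagonal(flat, n, 1e-6)

# Rebuild the rows to check that only the diagonal entries changed.
rows = [flat[i * n:(i + 1) * n] for i in range(n)]
for row in rows:
    print(row)
```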
I don’t understand why the ordering of the variables should matter. An earlier forum post (Ordering of variables affects learning in AR model) discussed a similar issue, but didn’t provide a solution. Does anyone have an idea why this happens?
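On the “num_particles” observation above: as far as I understand it, that setting controls how many Monte Carlo samples are averaged into each ELBO estimate, so raising it should mainly reduce the noise of the estimate rather than change what is being optimized. The toy sketch below illustrates only that averaging effect in plain Python. It is not Pyro code, and `noisy_estimate` is a made-up stand-in (true value 1.0 plus unit Gaussian noise), not the actual ELBO of the model above.

```python
import random
import statistics


def noisy_estimate(rng):
    """Made-up single-sample estimate: true value 1.0 plus unit Gaussian noise."""
    return 1.0 + rng.gauss(0.0, 1.0)


def averaged_estimate(rng, num_particles):
    """Average `num_particles` independent single-sample estimates."""
    total = sum(noisy_estimate(rng) for _ in range(num_particles))
    return total / num_particles


def estimator_std(num_particles, repeats=2000, seed=0):
    """Empirical standard deviation of the averaged estimator over many repeats."""
    rng = random.Random(seed)
    draws = [averaged_estimate(rng, num_particles) for _ in range(repeats)]
    return statistics.pstdev(draws)


std_1 = estimator_std(1)
std_16 = estimator_std(16)
print(f"std with 1 particle:   {std_1:.3f}")
print(f"std with 16 particles: {std_16:.3f}")
```

The spread shrinks roughly with the square root of the number of particles, which would be consistent with the partial improvement I see, but it does not by itself explain why the ordering of g0 and g1 matters.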