I built a multi-output Gaussian process model using the existing Coregionalization kernel implementation, and I'm currently using a heteroscedastic approach: I sample the variance at each time step for each of the output dimensions. The multi-output part works fine when modelling several time series, but I noticed something odd when sampling the variance. I can estimate the variance of the first time series (g0 in the code below) well, but not the variance of the second time series (g1). I ran some tests where I swapped the order of g0 and g1 in the code, and the issue reversed: I could then estimate g1, but not g0. I'm doing inference with SVI, and further experiments suggest that increasing num_particles in the ELBO estimator mitigates the issue somewhat, without removing it.
    def model(self, X, y=None):
        self.set_mode("model")

        N = X.size(0)
        N1 = sum(X[:, 1] == 1)
        N2 = sum(X[:, 2] == 1)

        Kff = self.f_kernel(X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

        Kgg0 = self.g0_kernel(X[:(N1), :]).contiguous()
        Kgg0.view(-1)[::(N1) + 1] += self.jitter  # add jitter to the diagonal
        Lgg0 = Kgg0.cholesky()

        Kgg1 = self.g1_kernel(X[(N1):, :]).contiguous()
        Kgg1.view(-1)[::(N2) + 1] += self.jitter  # add jitter to the diagonal
        Lgg1 = Kgg1.cholesky()

        if len(X.shape) > 1:
            zero_loc_f = torch.zeros_like(X[:, 0])
            zero_loc_g0 = torch.zeros_like(X[:N1, 0])
            zero_loc_g1 = torch.zeros_like(X[:N2, 0])
        else:
            zero_loc = torch.zeros_like(X)

        f = pyro.sample("f", dist.MultivariateNormal(zero_loc_f, scale_tril=Lff))
        g0 = pyro.sample("g0", dist.MultivariateNormal(zero_loc_g0, scale_tril=Lgg0))
        g1 = pyro.sample("g1", dist.MultivariateNormal(zero_loc_g1, scale_tril=Lgg1))
        g = torch.cat((g0, g1), 0)

        if y is None:
            return f
        else:
            return self.likelihood(f, g, y)
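As a side note on the jitter lines in the model: the `view(-1)[::n + 1]` idiom only hits the diagonal when `n` equals the number of rows of the matrix it is applied to. The sketch below shows this in plain Python on a row-major flat list, without torch. The helper names `add_jitter_flat` and `touched_positions` are made up for illustration and are not part of any library.

```python
def add_jitter_flat(flat, n, jitter):
    """Add `jitter` to the diagonal of an n x n matrix stored row-major in `flat`.

    Hypothetical helper, for illustration only.
    """
    # Element (i, i) sits at flat index i * n + i = i * (n + 1),
    # so stepping through the flat storage with stride n + 1 visits the diagonal.
    for idx in range(0, n * n, n + 1):
        flat[idx] += jitter
    return flat


def touched_positions(n, stride):
    """Return the (row, col) entries that flat[::stride] touches on an n x n matrix."""
    return [divmod(idx, n) for idx in range(0, n * n, stride)]


n = 3
K = add_jitter_flat([0.0] * (n * n), n, 1e-6)
rows = [K[r * n:(r + 1) * n] for r in range(n)]  # back to a nested 3 x 3 layout

correct = touched_positions(n, n + 1)  # stride matches the matrix size
wrong = touched_positions(n, n + 2)    # stride derived from a different size

print(rows)
print(correct)
print(wrong)
```

In the model, this means the stride `N2 + 1` used for `Kgg1` is only appropriate if `X[N1:]` has exactly `N2` rows, that is, if `N1 + N2 == N`. With a mismatched stride, the jitter lands on off-diagonal entries instead.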
I don't understand why the ordering of the variables should matter. There is also an earlier forum post ("Ordering of variables affects learning in AR model") that discussed a similar issue but didn't provide a solution. Do you guys have any idea why this happens?
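On the num_particles observation: that setting controls how many Monte Carlo samples are averaged to form each ELBO estimate, so raising it lowers the noise of the estimate. The toy sketch below is not the GP model above. It uses only the standard library and a made-up integrand `z**2` with `z ~ N(0, 1)`, and the function names are hypothetical. It only illustrates how the variance of an averaged estimator shrinks as more particles are used, and it does not by itself explain the ordering effect.

```python
import random
import statistics


def mc_estimate(num_particles, rng):
    """Average f(z) = z**2 over `num_particles` draws of z ~ N(0, 1)."""
    draws = [rng.gauss(0.0, 1.0) ** 2 for _ in range(num_particles)]
    return sum(draws) / num_particles


def estimator_variance(num_particles, repeats=2000, seed=0):
    """Empirical variance of the averaged estimator over independent repeats."""
    rng = random.Random(seed)  # fixed seed so the run is reproducible
    estimates = [mc_estimate(num_particles, rng) for _ in range(repeats)]
    return statistics.variance(estimates)


var_1 = estimator_variance(num_particles=1)
var_32 = estimator_variance(num_particles=32)

print(f"variance with 1 particle:   {var_1:.4f}")
print(f"variance with 32 particles: {var_32:.4f}")
```

If the second variance term suffers from a noisier gradient than the first, this kind of averaging would be expected to help only gradually, which would be consistent with the partial improvement described above.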