Dear Pyro users,
I am trying to use deep kernel learning (DKL) in Pyro, following the tutorial here. The task is binary classification.
For the setup:
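```python
import numpy as np
import torch
import pyro.contrib.gp as gp
import pyro.infer as infer
from torch.optim.lr_scheduler import StepLR
# classifier, retrieve_inducing_points, X_train, y_train, num_features, ndimension: defined elsewhere
cnn = classifier()  # CNN feature extractor, used as the input warping
rbf = gp.kernels.RBF(input_dim=num_features, lengthscale=torch.ones(num_features))
deep_kernel = gp.kernels.Warping(rbf, iwarping_fn=cnn)
# 128 inducing points as (N, 1, H, W) single-channel images
Xu = torch.from_numpy(retrieve_inducing_points(X_train, y_train, 128)
                      .reshape(-1, 1, ndimension, ndimension).astype(np.float32))
likelihood = gp.likelihoods.Binary()
latent_shape = torch.Size()  # one latent GP for binary classification
gpmodule = gp.models.VariationalSparseGP(X=Xu, y=None, kernel=deep_kernel, Xu=Xu,
                                         likelihood=likelihood, latent_shape=latent_shape,
                                         num_data=X_train.shape[0],  # an int, not a shape
                                         whiten=True, jitter=2e-4)
optimizer = torch.optim.Adam(gpmodule.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=9, gamma=0.5)
elbo = infer.TraceMeanField_ELBO()
loss_fn = elbo.differentiable_loss
```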
`retrieve_inducing_points` is simply a helper function I wrote to select the inducing points (128 of them) from the training data.
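The helper itself is not shown in the post. Purely as a hypothetical stand-in, so the setup above can be run end to end, a class-balanced random subset of the training set could serve as the inducing points. The balancing, the fixed seed, and the whole function body are assumptions for illustration, not the author's implementation.

```python
import numpy as np


def retrieve_inducing_points(X, y, num_inducing, seed=0):
    """Hypothetical stand-in: draw a class-balanced random subset of X.

    X: array of shape (n_samples, ...), e.g. flattened images.
    y: class labels of shape (n_samples,).
    Returns num_inducing rows of X, split as evenly as possible across classes.
    """
    rng = np.random.default_rng(seed)
    classes = np.unique(y)
    counts = [num_inducing // len(classes)] * len(classes)
    counts[-1] += num_inducing - sum(counts)  # any remainder goes to the last class
    picks = [
        rng.choice(np.flatnonzero(y == c), size=k, replace=False)  # no duplicate points
        for c, k in zip(classes, counts)
    ]
    return X[np.concatenate(picks)]
```

With 200 samples split evenly between two classes, asking for 128 points returns 64 distinct rows from each class.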
The training loop looks like this:
```python
from tqdm import tqdm

for epoch in tqdm(range(1, epochs + 1)):
    train(train_loader, gpmodule, optimizer, loss_fn, epoch)
```
My issue is that the final test results change depending on whether or not I run validation inside the training loop. This surprises me, since to my understanding running validation should not affect the final test results. Does anyone have an idea why this happens?
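One factor that may be worth ruling out is consumption of PyTorch's global random number generator. In Pyro, `gp.likelihoods.Binary` draws random samples whenever it is called, including at prediction time (e.g. `gpmodule.likelihood(f_loc, f_var)`), and both a `DataLoader` with `shuffle=True` and the ELBO's Monte Carlo samples draw from that same global generator. A validation pass that uses it therefore shifts every later minibatch order and ELBO sample, so two otherwise identical runs follow different random streams. The sketch below reproduces the effect with a toy shuffling loader and a hypothetical `fake_validation` stand-in, and shows that saving and restoring the RNG state around validation removes it.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def epoch_order(loader):
    """Indices in the order a shuffling DataLoader yields them for one epoch."""
    return [int(i) for (batch,) in loader for i in batch]


def fake_validation():
    """Hypothetical stand-in for a validation pass that consumes random numbers,
    e.g. sampling labels from a Binary likelihood."""
    torch.rand(10)


def shuffle_orders(validate, restore_rng, epochs=3):
    torch.manual_seed(0)
    loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=4, shuffle=True)
    orders = []
    for _ in range(epochs):
        orders.append(epoch_order(loader))  # stands in for train(...)
        if validate:
            state = torch.get_rng_state()  # snapshot of the global CPU generator
            fake_validation()
            if restore_rng:
                torch.set_rng_state(state)  # later epochs see the same random stream
    return orders


baseline = shuffle_orders(validate=False, restore_rng=False)
with_val = shuffle_orders(validate=True, restore_rng=False)
restored = shuffle_orders(validate=True, restore_rng=True)
```

Here `with_val` will typically differ from `baseline` from the second epoch on, while `restored` matches it exactly. If restoring the state (plus `torch.cuda.get_rng_state()`/`torch.cuda.set_rng_state()` when training on a GPU) makes the runs with and without validation agree, the gap in test results would be run-to-run randomness rather than validation leaking into training. If it does not, the next thing to check would be whether validation leaves the CNN in `eval()` mode for the following training epochs.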
Many thanks in advance.