I am currently trying to use DKL, following the tutorial at https://pyro.ai/examples/dkl.html, in an active learning scenario. The task I am working on is binary classification, so I am also using the Binary likelihood as specified in the tutorial. This all works reasonably well and the classifier is learning.
Since in an active learning scenario I need to keep adding data points to the training set, I would like to use the model to decide which training point to add next. To do that, I use the model to make predictions on the remaining data points and then select the data point where the model is most uncertain (has the highest variance). As far as I can tell, the Binary likelihood I use returns a vector of zeros and ones when calling Binary(some_vector). What I need, though, is a full-fledged Bernoulli distribution that gives me a mean and a variance for each of my remaining data points. To make that happen, this is the prediction method I came up with:
with torch.no_grad():
    for data, target in data_loader:
        if self.cuda:
            data, target = data.cuda(), target.cuda()
        target = target.float()
        # get prediction of GP model on data
        f_loc, f_var = self.gpmodule(data)
        # convert f_loc and f_var into bernoulli distribution
        # this I copied from the forward method of the Binary likelihood
        f = torch.sigmoid(dist.Normal(f_loc, f_var.sqrt())())
        y_dist = dist.Bernoulli(f)
        # this I copied from the forward method of the Binary likelihood
        pred = self.gpmodule.likelihood(f_loc, f_var)
        # I return the Bernoulli distribution the targets and the predictions from calling the Binary likelihood
        return y_dist, target, pred
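One thing that may be worth checking in the snippet above: f is computed from a single draw of the latent function, so the resulting Bernoulli probability is a one-sample estimate and can change from call to call. A common alternative is to average the sigmoid over many draws. The following is only a minimal sketch of that idea in plain Python for a single data point with scalar f_loc and f_var; the names sigmoid and mc_predictive_prob are hypothetical helpers, not part of Pyro or PyTorch.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function for a scalar."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def mc_predictive_prob(f_loc, f_var, num_samples=1000, rng=None):
    """Estimate p(y=1) by averaging sigmoid(f) over draws f ~ N(f_loc, f_var).

    f_loc and f_var stand in for the latent mean and variance the GP
    returns for one data point (assumption: scalars, not tensors).
    """
    if rng is None:
        rng = random.Random(0)  # fixed seed so the sketch is reproducible
    std = math.sqrt(f_var)
    total = 0.0
    for _ in range(num_samples):
        total += sigmoid(rng.gauss(f_loc, std))
    return total / num_samples


# A latent mean of 0 with large variance should give a probability near 0.5.
p = mc_predictive_prob(f_loc=0.0, f_var=4.0)
```

With a single draw (num_samples=1) this reduces to what the snippet above does; more draws give a smoother estimate at the cost of extra computation.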
This gives me the desired distribution, y_dist, for all my remaining data points. I then use y_dist to get the mean (y_dist.mean), the variance (y_dist.variance) and, if I need them, the binary predictions.
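For reference, the variance of a Bernoulli distribution with probability p is p(1 - p), so picking the point with the highest variance is the same as picking the point whose predicted probability is closest to 0.5. A minimal sketch of that selection step, assuming the per-point probabilities are available as a plain Python list (the function names and the example values are made up for illustration):

```python
def bernoulli_variance(p):
    """Variance of a Bernoulli(p) distribution."""
    return p * (1.0 - p)


def most_uncertain_index(probs):
    """Return the index of the probability with the largest Bernoulli variance."""
    variances = [bernoulli_variance(p) for p in probs]
    return max(range(len(probs)), key=variances.__getitem__)


# Hypothetical predicted probabilities for four unlabeled pool points.
pool_probs = [0.95, 0.52, 0.10, 0.70]
query_idx = most_uncertain_index(pool_probs)  # index 1, since 0.52 is closest to 0.5
```

Note that this criterion only reflects how close the predicted probability is to 0.5; it does not separately account for how large the latent variance f_var is.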
I think this does the trick, but since I am new to this area and to the framework, I am not entirely sure this is how it's done. Maybe someone a little more experienced can take a quick look and tell me whether I'm doing it right or whether I messed up big time!
Thanks guys …