I’ve built a hierarchical model in Pyro and it works as it should. However, my dataset is imbalanced and I would like to find a way to weight the per-sample loss (to down-weight data points that are over-represented, not to be confused with “sampling” from a distribution). In Scikit-learn this can be achieved like so (where the `weights` variable gives my sample importance):
```python
from sklearn.linear_model import LinearRegression

reg = LinearRegression().fit(X, y, sample_weight=weights)
```
And in vanilla PyTorch it can be done by returning (and subsequently modifying) the per-sample loss (I'm not providing the full model here, but I hope you get the point):
```python
def fit(self, X, y, weights, lr=0.05, steps=2000):
    """Fit model to data."""
    pyro.clear_param_store()
    # set up optimizer and loss function
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    self.loss_fn = torch.nn.MSELoss(reduction='none')
    for j in range(steps):
        # initialize gradients to zero
        self.optimizer.zero_grad()
        # run the model forward on the data
        y_pred = self.model(X).squeeze(-1)
        # calculate the per-sample mse loss
        loss = self.loss_fn(y_pred, y)
        # for each sample, multiply by its sample weight
        loss = loss * weights
        # reduce to a scalar, backprop and step
        loss.sum().backward()
        self.optimizer.step()
```
The key in that snippet is `reduction='none'`, which gives you back the loss for each sample in X rather than the sum.
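For concreteness, here is a tiny toy example of that behaviour (made-up numbers, nothing from my actual model), just to show what I mean by per-sample losses that can be weighted before reducing:

```python
import torch

loss_fn = torch.nn.MSELoss(reduction='none')
y_pred = torch.tensor([1.0, 2.0, 3.0])
y_true = torch.tensor([1.5, 2.0, 2.0])

per_sample = loss_fn(y_pred, y_true)           # tensor([0.25, 0.00, 1.00]) – one loss per data point
weights = torch.tensor([2.0, 1.0, 0.5])        # my sample-importance weights
weighted_total = (per_sample * weights).sum()  # scalar that .backward() can be called on
```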
I’ve been trying to get this to work using SVI and Trace_ELBO for weeks. I’ve searched forums, looked at source code and read tutorials, and I just can’t find it. The closest I’ve gotten is modifying the `log_prob_sum()` method in `poutine.trace`, but it seems to return losses for each individual latent variable rather than for each data point.
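Roughly, this is the kind of thing I’ve been poking at (a simplified sketch, assuming a `model(X, y)` call signature rather than my real hierarchical model), which is how I noticed the log-probs are organised per site:

```python
import pyro.poutine as poutine

# trace the model once and inspect the log-probs it records
trace = poutine.trace(model).get_trace(X, y)
trace.compute_log_prob()

for name, site in trace.nodes.items():
    if site["type"] == "sample":
        # one log_prob tensor per sample site (latent or observed),
        # not one loss per data point – which is what I'm after
        print(name, site["log_prob"].shape)

# trace.log_prob_sum() then just sums all of these terms into a scalar
```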
Any ideas on how to solve this would be much appreciated!