I’ve built a hierarchical model in Pyro and it works as it should. However, my dataset is imbalanced, and I would like to find a way to weight the per-sample loss (to down-weight data points that are over-represented; not to be confused with “sampling” from a distribution). In scikit-learn this can be achieved like so (where the “weights” variable gives my sample importance):
```python
from sklearn.linear_model import LinearRegression
reg = LinearRegression().fit(X, y, sample_weight=weights)
```
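To make the meaning of sample_weight concrete: for least squares, an integer weight of k acts like repeating that data point k times, so down-weighting an over-represented point is the same as "removing copies" of it. A small sketch with made-up toy data (X, y and weights here are illustrative, not my real data):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Toy data (hypothetical): four points, the last one given weight 3.
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0.1, 0.9, 2.2, 2.7])
weights = np.array([1.0, 1.0, 1.0, 3.0])

# Weighted fit: each squared residual is multiplied by its sample weight.
weighted = LinearRegression().fit(X, y, sample_weight=weights)

# Unweighted fit on a copy of the data where the last point appears 3 times in total.
X_dup = np.vstack([X, X[[3, 3]]])
y_dup = np.concatenate([y, y[[3, 3]]])
duplicated = LinearRegression().fit(X_dup, y_dup)

# The two fits should give the same coefficients and intercept.
print(weighted.coef_, weighted.intercept_)
print(duplicated.coef_, duplicated.intercept_)
```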
In vanilla PyTorch, it can be done by returning (and then modifying) the per-sample loss (I’m not providing the full model here, but I hope you get the point):
```python
import torch

def fit(self, X, y, weights, lr=0.05, steps=2000):
    """Fit model to data
    """
    # set up optimizer and loss function
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    self.loss_fn = torch.nn.MSELoss(reduction='none')
    for j in range(steps):
        # initialize gradients to zero
        self.optimizer.zero_grad()
        # run the model forward on the data
        y_pred = self.model(X).squeeze(-1)
        # calculate the per-sample mse loss (one value per row of X)
        loss = self.loss_fn(y_pred, y)
        # for each sample, multiply by sample weight, then reduce to a scalar
        loss = (loss * weights).mean()
        # backpropagate and update the parameters
        loss.backward()
        self.optimizer.step()
```
The key in that snippet is reduction='none', which gives you back the loss for each sample in X rather than a single reduced value (the default is reduction='mean').
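A small sketch with made-up numbers showing what reduction='none' returns compared with the default, and one way to reduce the weighted per-sample losses. Dividing by the sum of the weights (a weighted mean) is a design choice of this example; a plain .mean() or .sum() of the weighted losses only rescales the gradient.

```python
import torch

y_pred = torch.tensor([1.0, 2.0, 3.0])
y_true = torch.tensor([1.5, 2.0, 1.0])
weights = torch.tensor([1.0, 0.5, 0.25])  # hypothetical per-sample weights

# reduction='none': one squared error per sample, shape (3,)
per_sample = torch.nn.MSELoss(reduction='none')(y_pred, y_true)
# default reduction='mean': a single scalar, the average of per_sample
default = torch.nn.MSELoss()(y_pred, y_true)

# weighted mean of the per-sample losses
weighted = (per_sample * weights).sum() / weights.sum()

print(per_sample)  # tensor([0.2500, 0.0000, 4.0000])
print(default)     # mean of the three values above
print(weighted)
```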
I’ve been trying to get this to work using SVI and Trace_ELBO for weeks. I’ve searched the forums, looked at the source code and gone through the tutorials, and I just can’t find it. The closest I’ve gotten is modifying the log_prob_sum() method in poutine.trace, but it seems to return losses for each individual latent variable rather than for each data point.
Any ideas on how to solve this would be much appreciated!