Per-sample (training data) loss using SVI and Trace_ELBO

I’ve built a hierarchical model in Pyro and it works as it should. However, my dataset is imbalanced and I would like to find a way to weight the per-sample loss (to down-weight data points that are over-represented, not to be confused with “sampling” from a distribution). In Scikit-learn this can be achieved like so (where the weights variable gives my sample importance):

from sklearn.linear_model import LinearRegression
reg = LinearRegression().fit(X, y, sample_weight=weights)

And in vanilla PyTorch it can be done by returning (and then weighting) the per-sample loss (not providing the full model here, but I hope you get the point):

    def fit(self, X, y, weights, lr=0.05, steps=2000):
        """Fit model to data
        """
        pyro.clear_param_store()

        # set up optimizer and loss function
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        for j in range(steps):
            # initialize gradients to zero
            self.optimizer.zero_grad()
            
            # run the model forward on the data
            y_pred = self.model(X).squeeze(-1)

            # calculate the mse loss
            loss = self.loss_fn(y_pred, y)
            
            # for each sample, multiply by sample weight
            loss = loss * weights
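            # (assumed completion, not in the original post: reduce the
            # weighted per-sample losses to a scalar and take a gradient step)
            loss = loss.mean()
            loss.backward()
            self.optimizer.step()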

The key in that snippet is reduction='none', which gives you back the loss for each sample in X rather than the sum.
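For concreteness, here is a tiny standalone illustration (toy numbers, not from my actual model) of what reduction='none' returns and how the per-sample weights come in:

    # toy example: reduction='none' gives one loss per element
    import torch

    loss_fn = torch.nn.MSELoss(reduction='none')
    y_pred = torch.tensor([1.0, 2.0, 3.0])
    y_true = torch.tensor([1.0, 0.0, 3.0])
    weights = torch.tensor([1.0, 0.5, 2.0])

    per_sample = loss_fn(y_pred, y_true)      # tensor([0., 4., 0.])
    weighted = (per_sample * weights).mean()  # scalar, ready for .backward()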

I’ve been trying to get this to work with SVI and Trace_ELBO for weeks. I’ve searched the forums, read the source code, and gone through the tutorials, and I just can’t find it. The closest I’ve gotten is modifying the log_prob_sum() method in poutine.trace, but that seems to return a loss for each individual latent variable rather than for each sample.

Any ideas on how to solve this would be much appreciated!

You can use pyro.poutine.scale to multiply log-probability terms by a vector of weights.

For example, if you had a batch of data (X, y) with y.shape == (N, D), where N is the batch size, and you wanted to scale each log-likelihood term by a sample weight weights[i] where weights.shape == (N,), then you might write:

def model(X, y, weights):
    ...
    with pyro.plate("data", N):
        ...
        with pyro.poutine.scale(scale=weights):
            obs_dist = compute_my_obs_dist(X, ...)
            pyro.sample("y", obs_dist, obs=y)

So simple! I came across the scale method at some point, but I think it was in the context of scaling the loss returned by SVI, which was a single number, so I thought it couldn’t be used the way I intended. I added the with pyro.poutine.scale(scale=weights): snippet to my model and it runs like butter. So far I’ve only tested with all weights set to 1.0, but I’ll run some additional tests later today. Thanks a million!
