# Per-sample (training data) loss using SVI and Trace_ELBO

I’ve built a hierarchical model in Pyro and it works as it should. However, my dataset is imbalanced and I would like to find a way to weight the per-sample loss (to down-weight data points that are over-represented, not to be confused with “sampling” from a distribution). In scikit-learn this can be achieved like so (where the `weights` variable gives my sample importances):

``````from sklearn.linear_model import LinearRegression
reg = LinearRegression().fit(X, y, sample_weight=weights)
``````

And in vanilla PyTorch it can be done by returning (and subsequently modifying) the per-sample loss (I’m not providing the full model here, but I hope you get the point):

``````
def fit(self, X, y, weights, lr=0.05, steps=2000):
    """Fit model to data."""
    pyro.clear_param_store()

    # set up optimizer and loss function
    # (optimizer details were elided above; Adam is just a stand-in)
    optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    self.loss_fn = torch.nn.MSELoss(reduction='none')

    for j in range(steps):

        # run the model forward on the data
        y_pred = self.model(X).squeeze(-1)

        # calculate the per-sample mse loss
        loss = self.loss_fn(y_pred, y)

        # for each sample, multiply by sample weight
        loss = loss * weights

        # reduce to a scalar and take a gradient step
        optimizer.zero_grad()
        loss.sum().backward()
        optimizer.step()
``````

The key in that snippet is `reduction='none'`, which gives you back the loss for each sample rather than the sum.

I’ve been trying to get this to work using SVI and Trace_ELBO for weeks. I’ve searched the forums, read the source code, and gone through the tutorials, and I just can’t find it. The closest I’ve gotten is modifying the `log_prob_sum()` method in `poutine.trace`, but it seems to return a loss for each individual latent variable.
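To illustrate what I mean, tracing a toy model (hypothetical, not my actual model) shows each sample site carrying its own `log_prob_sum`, while the per-sample shape only survives in the site’s `log_prob`:

``````
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# toy model just to show the trace structure (not my real model)
def toy_model(y):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(y)):
        pyro.sample("y", dist.Normal(loc, 1.0), obs=y)

tr = poutine.trace(toy_model).get_trace(torch.randn(5))
tr.compute_log_prob()
for name, site in tr.nodes.items():
    if site["type"] == "sample":
        # "log_prob" keeps the per-sample shape; "log_prob_sum" is one scalar per site
        print(name, site["log_prob"].shape, site["log_prob_sum"])
``````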

Any ideas on how to solve this would be much appreciated!

You can use `pyro.poutine.scale` to multiply log-probability terms by a vector of weights.

For example, if you had a batch of data `(X, y)` with `y.shape == (N, D)` where `N` is the batch size, and you wanted to scale each log-likelihood term by a sample weight `weights[i]` where `weights.shape == (N,)`, then you might write

``````
def model(X, y, weights):
    ...
    with pyro.plate("data", N):
        ...
        with pyro.poutine.scale(scale=weights):
            obs_dist = compute_my_obs_dist(X, ...)
            pyro.sample("y", obs_dist, obs=y)
``````
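For a fuller picture, here is a minimal end-to-end sketch: a toy Bayesian linear regression with an `AutoNormal` guide. The model, data, and hyperparameters are illustrative (and the target is a scalar rather than `(N, D)` for simplicity), not taken from the question:

``````
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

def model(X, y, weights):
    N, D = X.shape
    w = pyro.sample("w", dist.Normal(torch.zeros(D), 1.0).to_event(1))
    b = pyro.sample("b", dist.Normal(0.0, 1.0))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    mean = X @ w + b
    with pyro.plate("data", N):
        # each data point's log-likelihood is multiplied by weights[i]
        with poutine.scale(scale=weights):
            pyro.sample("y", dist.Normal(mean, sigma), obs=y)

# toy data; in practice weights might be, e.g., inverse class frequencies
N, D = 100, 3
X = torch.randn(N, D)
y = X @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(N)
weights = torch.ones(N)

pyro.clear_param_store()
guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.05}), loss=Trace_ELBO())
for step in range(2000):
    svi.step(X, y, weights)
``````

Note that `svi.step` forwards its arguments to both the model and the guide, so the weights simply ride along with each batch of data.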

So simple! I came across the scale method at some point, but I think it was in the context of scaling the loss returned by SVI, which was a single number, so I thought it couldn’t be used the way I intended. I added the `with pyro.poutine.scale(scale=weights):` snippet to my model and it runs like butter. So far I’ve only tested with all weights set to 1.0, but I’ll run some additional tests later today. Thanks a million!
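For anyone curious, a quick sanity check along these lines (hypothetical, reusing `model`, `X`, `y`, and `N` from the sketch above): with the latent samples held fixed via the RNG seed, doubling the weights should double the observed site’s log-probability:

``````
import torch
import pyro
import pyro.poutine as poutine

pyro.set_rng_seed(0)
tr1 = poutine.trace(model).get_trace(X, y, torch.ones(N))
pyro.set_rng_seed(0)
tr2 = poutine.trace(model).get_trace(X, y, 2.0 * torch.ones(N))
tr1.compute_log_prob()
tr2.compute_log_prob()

# the scale handler multiplies the per-sample log-probs elementwise
assert torch.allclose(tr2.nodes["y"]["log_prob"], 2.0 * tr1.nodes["y"]["log_prob"])
``````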
