Deal with class imbalance

Hi,
I’m using SVI with a relatively imbalanced dataset (1 positive to 10 negatives). Is there a recommended way to deal with class imbalance in Pyro? Ideally I’m looking for something like pos_weight in PyTorch’s loss functions, but I’m not sure whether that makes sense for Trace_ELBO().
Alternatively I’ll over-/under-sample my data.
Thanks.


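You can pass in a data weight and wrap the model body in poutine.scale, which multiplies the log-probability of every sample statement inside the context by that weight: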

import pyro.poutine as poutine

def model(data, data_scale):
    with poutine.scale(scale=data_scale):
        ...  # remainder of model as usual

Then during training you can pass a larger scale for positive minibatches, assuming each minibatch is either all positive or all negative:

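from pyro.infer import SVI, Trace_ELBO

svi = SVI(model, guide, optim, loss=Trace_ELBO())  # guide, optim defined elsewhere
for batch, is_positive in my_data_partitioner():  # yields single-class minibatches
    data_scale = 10. if is_positive else 1.  # upweight the rarer positive class
    svi.step(batch, data_scale)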