Hi,
I’m using SVI with a relatively imbalanced dataset (1 positive to 10 negatives). Is there a recommended way to deal with class imbalance in Pyro? Ideally I’m looking for something like the pos_weight argument of PyTorch’s loss functions (e.g. BCEWithLogitsLoss), but I’m not sure whether this makes sense for Trace_ELBO().
Otherwise, I’ll over- or under-sample my data.
Thanks.
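You can pass in a data weight and apply it with poutine.scale, which multiplies the log-probability of every sample site inside its context by that factor: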
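from pyro import poutine

def model(data, data_scale):
    # every sample site inside this block has its log-prob multiplied by data_scale
    with poutine.scale(scale=data_scale):
        ...  # remainder of model as usual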
Then during training you can pass in different minibatches, assuming each minibatch is either all positive or all negative:
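from pyro.infer import SVI, Trace_ELBO

svi = SVI(model, guide, optim, loss=Trace_ELBO())  # guide/optim as usual; guide takes the same args
for batch, is_positive in my_data_partitioner():
    # weight positive minibatches 10x to offset the 1:10 class ratio
    data_scale = (10. if is_positive else 1.)
    svi.step(batch, data_scale)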