Dealing with class imbalance in an image segmentation task

I’m performing image segmentation where each output image contains several classes, and the classes are imbalanced. How can I weight the loss of each class differently?

(In the post “Deal with class imbalance”, the suggested solution uses poutine.scale. However, that solution requires each minibatch to contain only one class. In image segmentation this clearly can’t be the case, since each image can contain pixels of different classes.)

Hi @matohak, I’m not familiar with weighted losses in segmentation problems. Could you provide more details about your problem and your motivation for weighting the losses differently? Just guessing: if you have per-pixel labeled training images (hand-segmented images) and you want to weight that training data differently, you can use a pyro.plate over pixels and then poutine.scale, where the scale factor is a tensor over pixels whose weights depend on the known classification. Something like this:

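```python
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# suppose there are 3 classes with 3 different weights
weights = torch.tensor([0.1, 0.8, 2.0])

def model(input_image, training_data=None):
    # my_neural_net is a placeholder for your own nn.Module; here it is assumed
    # to return per-pixel class probabilities of shape
    # (batch, width, height, num_classes)
    pyro.module("my_neural_net", my_neural_net)  # register the net's parameters
    prediction = my_neural_net(input_image)

    batch_size, width, height, num_classes = prediction.shape
    with pyro.plate("batch", batch_size, dim=-3):
        with pyro.plate("height", height, dim=-1):
            with pyro.plate("width", width, dim=-2):
                if training_data is None:
                    scale = 1.0  # don't scale when generating data
                else:
                    # per-pixel weight, looked up from each pixel's true class
                    scale = weights[training_data]
                with poutine.scale(scale=scale):
                    pyro.sample("class", dist.Categorical(probs=prediction),
                                obs=training_data)
```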

@fritzo Just my two cents, but in terms of class imbalance I think the OP is talking about the common ML problem where some classes have very few examples in the dataset. In a supervised training scenario, the model can often get these uncommon classes wrong and still achieve very high accuracy. One solution is to adjust the loss function to increase the weight of correctly classifying the uncommon classes; in other words, there is a bigger penalty for misclassifying these minority classes. Another is to pre-process the dataset so that each training batch includes some examples from the minority classes.
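Outside of Pyro, the "adjust the loss function" approach is usually a class-weighted cross-entropy. The sketch below assumes plain PyTorch and uses `torch.nn.functional.cross_entropy` with its `weight` argument; the helper name `weighted_pixel_ce` and the weight values are illustrative only. It accepts channels-first logits of shape (batch, classes, height, width) and integer targets of shape (batch, height, width), the same layout as the prediction and ground truth described later in this thread.

```python
import torch
import torch.nn.functional as F

def weighted_pixel_ce(logits, target, class_weights):
    """Class-weighted cross-entropy over every pixel (illustrative helper).

    logits:        (batch, num_classes, height, width) raw network scores
    target:        (batch, height, width) integer class labels
    class_weights: (num_classes,) larger values penalize mistakes on that class more
    """
    # With reduction="mean" and a weight tensor, F.cross_entropy divides by the
    # sum of the weights of the target pixels, not by the number of pixels.
    return F.cross_entropy(logits, target, weight=class_weights, reduction="mean")

# Toy usage with made-up shapes and weights (class 0 playing "background").
torch.manual_seed(0)
logits = torch.randn(2, 5, 4, 4)
target = torch.randint(0, 5, (2, 4, 4))
weights = torch.tensor([0.1, 1.0, 1.0, 2.0, 2.0])
loss = weighted_pixel_ce(logits, target, weights)
```

Because the normalization is by the summed target weights, the loss scale stays comparable across batches with different class mixes; `reduction="sum"` is an alternative if you prefer an unnormalized total.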

For image segmentation the problem is compounded: there are so many pixels per image that misclassified minority-class pixels can get washed out by the much larger number of majority-class pixels in each example.

So that is the context. I am not sure if this description affects any of your subsequent comments.

In my case, my data are medical images of different anatomical structures. Some classes (e.g. background) are much more common than others, so with the default loss function the model tends to predict everything as background; hence weighting is needed.
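One common heuristic (an assumption here, not something proposed in this thread) is to set each class weight inversely proportional to how often that class appears among the ground-truth pixels, so a dominant background class gets a small weight. A minimal sketch in PyTorch; the function name `inverse_frequency_weights` is hypothetical:

```python
import torch

def inverse_frequency_weights(labels, num_classes):
    """Per-class weights proportional to 1 / (pixel count of the class).

    labels: integer tensor of ground-truth class indices, any shape
    Returns a (num_classes,) float tensor, normalized so that a perfectly
    balanced label set gives weight 1.0 to every class.
    """
    total = labels.numel()
    counts = torch.bincount(labels.flatten(), minlength=num_classes).float()
    # Guard against classes that never appear, to avoid division by zero.
    counts = counts.clamp(min=1.0)
    return total / (num_classes * counts)

# Toy example: background (class 0) fills 90 of 100 pixels.
labels = torch.zeros(1, 10, 10, dtype=torch.long)
labels[0, 0, :] = 1  # 10 pixels of class 1
weights = inverse_frequency_weights(labels, num_classes=2)
```

In practice such weights would be computed once over the training labels and then used as the per-class `weights` tensor in the model snippet above. Very large weights for extremely rare classes can make training noisy, so some people clip or smooth them (for example with a square root).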

Your code runs fine (thanks!), but I just want to make sure it’s working as I intended it to. My data dimensions are as follows:

- `prediction.shape` == batchsize * 5 classes * 256 * 256. The weights should be applied across the class dimension.
- `training_data.shape` (the ground truth) == batchsize * 256 pixels * 256 pixels. These are ground-truth labels with pixel values in {0, 1, 2, 3, 4}, indicating which of the 5 class channels each pixel belongs to.
- `input_image.shape` == batchsize * 3 channels * 256 * 256.

I guess scale is broadcast with obs, is that right?

Also, if we apply plates along width and height, we’d be assuming the pixels are independent. In reality, adjacent pixels are more likely to belong to the same class. In that case, would it be better not to plate over width and height?

> “prediction.shape” == batchsize * 5 classes * 256 * 256

For my code snippet, I think you would need `prediction.shape` == batchsize * 256 * 256 * 5 (classes in the last dimension), or you could update the code to match your `prediction.shape`.
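`Categorical` in both `torch.distributions` and `pyro.distributions` treats the last dimension as the class dimension, so a channels-first network output can be converted with a single `permute`. A minimal sketch with random stand-in data; treating the scores as logits is an assumption (pass `probs=` instead if your network already outputs normalized probabilities):

```python
import torch

# Stand-in for a network output in the channels-first layout from the question:
# (batch, num_classes, height, width).
batch, num_classes, height, width = 2, 5, 256, 256
prediction = torch.randn(batch, num_classes, height, width)

# Move the class dimension to the end: (batch, height, width, num_classes).
logits_last = prediction.permute(0, 2, 3, 1)

# Each pixel now gets its own categorical distribution over the 5 classes.
pixel_dist = torch.distributions.Categorical(logits=logits_last)
```

The resulting batch shape (batch, height, width) matches the ground-truth tensor, which is what the nested plates in the model snippet expect.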

> I guess scale is broadcast with obs, is that right?

Correct, or more precisely, `scale` is broadcast with `obs_dist.log_prob(obs)`.
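Since `poutine.scale` multiplies a sample site's log-probability by the scale factor, the per-pixel weights enter the loss elementwise. The sketch below mimics that arithmetic in plain `torch` (it does not call Pyro itself) and checks that the negated, summed weighted log-probability equals PyTorch's class-weighted cross-entropy with `reduction="sum"`. Shapes and weights are made up for illustration.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
batch, num_classes, height, width = 2, 5, 8, 8
weights = torch.tensor([0.1, 0.8, 2.0, 1.0, 1.0])            # one weight per class

logits = torch.randn(batch, height, width, num_classes)       # classes last
obs = torch.randint(0, num_classes, (batch, height, width))   # ground-truth labels

log_prob = Categorical(logits=logits).log_prob(obs)  # shape (batch, height, width)
scale = weights[obs]                                  # shape (batch, height, width)
scaled_log_prob = scale * log_prob                    # one weight per pixel

# Negating and summing gives a class-weighted pixel-wise cross-entropy.
weighted_nll = -scaled_log_prob.sum()
reference = F.cross_entropy(
    logits.permute(0, 3, 1, 2), obs, weight=weights, reduction="sum"
)
```

Here `scale` already has exactly the shape of `log_prob`, so broadcasting is trivial; a scalar `scale` (as in the generation branch of the model) would broadcast to every pixel.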

> In [the case of correlated adjacent pixels], would it be better not to plate over width and height?

It’s valid to use plates because the pixels are conditionally independent given upstream nodes; that is, the dependence should be captured by your neural net. I think your intuition is a good criticism of mean-field VAEs that don’t capture posterior dependencies among pixels, but you can create guides that do capture posterior correlations among pixels, and it is still valid to model the final observation as conditionally independent.