Dealing with class imbalance in an image segmentation task

Hi @matohak, I’m not familiar with weighted losses in segmentation problems. Could you provide more details about your problem and the motivation for weighting losses differently? Just guessing: if you have per-pixel training labels (hand-segmented images) and you want to weight that training data differently, you can use a pyro.plate over pixels together with poutine.scale, where the scale factor is a tensor over pixels whose weights depend on the known classification. Something like this:

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# suppose there are 3 classes with 3 different weights
weights = torch.tensor([0.1, 0.8, 2.0])

def model(input_image, training_data=None):
  # per-pixel class probabilities, e.g. shape (batch, width, height, 3)
  prediction = my_neural_net(input_image)

  batch_size, width, height = input_image.shape
  with pyro.plate("batch", batch_size, dim=-3):
    with pyro.plate("height", height, dim=-1):
      with pyro.plate("width", width, dim=-2):
        if training_data is None:
          scale = 1.0  # don't scale when generating data
        else:
          # look up each pixel's weight from its known class label
          scale = weights[training_data]
        with poutine.scale(scale=scale):
          pyro.sample("class", dist.Categorical(prediction),
                      obs=training_data)
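To see what poutine.scale does to the objective here: it multiplies each pixel's log-likelihood by its weight, so the model's log-likelihood becomes a class-weighted sum of per-pixel log-probabilities, exactly like a weighted cross-entropy loss. Below is a plain-torch sketch of that arithmetic on a tiny made-up 2x2 "image" with uniform predictions (the names `labels` and `probs` are just illustrative):

```python
import math
import torch

weights = torch.tensor([0.1, 0.8, 2.0])   # one weight per class
labels = torch.tensor([[0, 1], [2, 0]])   # hand-segmented 2x2 labels
probs = torch.full((2, 2, 3), 1.0 / 3.0)  # uniform class predictions

# weights[labels] is the same indexing trick as in the model:
# it broadcasts each class's weight to every pixel of that class
scale = weights[labels]                   # shape (2, 2)

# per-pixel log-probability of the observed class
logp = torch.log(probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1))

# scaled log-likelihood = weighted sum of per-pixel log-probs;
# here every pixel has log-prob log(1/3), so this equals
# -(0.1 + 0.8 + 2.0 + 0.1) * log(3)
weighted_loglik = (scale * logp).sum()
print(weighted_loglik.item())
```

With uniform predictions the result is -3.0 * log(3) ≈ -3.296, and minority-class pixels (weight 2.0) dominate the gradient, which is the point of the weighting.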