Hi @matohak, I’m not familiar with weighted losses in segmentation problems. Could you provide further details about your problem and your motivation for weighting the losses differently? Just guessing: if you have per-pixel training labels (hand-segmented images) and you want to weight that training data differently, you can use a `pyro.plate` over pixels and then `poutine.scale`, where the scale factor is a tensor over pixels whose weights depend on the known class. Something like this:
```python
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# suppose there are 3 classes with 3 different weights
weights = torch.tensor([0.1, 0.8, 2.0])

def model(input_image, training_data=None):
    # per-pixel class probabilities, shape (batch_size, width, height, num_classes)
    prediction = my_neural_net(input_image)
    batch_size, width, height = input_image.shape
    with pyro.plate("batch", batch_size, dim=-3):
        with pyro.plate("height", height, dim=-1):
            with pyro.plate("width", width, dim=-2):
                if training_data is None:
                    scale = 1.0  # don't scale when generating data
                else:
                    scale = weights[training_data]  # per-pixel weight of the observed class
                with poutine.scale(scale=scale):
                    pyro.sample("class", dist.Categorical(probs=prediction),
                                obs=training_data)
```