import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    """Dice loss for binary segmentation.

    The loss is ``1 - dice``, where::

        dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    with ``p = sigmoid(inputs)`` and ``t = targets``, both flattened to 1-D.
    The sums run over every element of the flattened tensors, so a single
    coefficient is computed for the whole batch rather than one per sample.

    Args:
        smooth: Constant added to both numerator and denominator. It keeps
            the ratio defined when both the prediction and the target sums
            are zero.
    """

    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar Dice loss for a batch.

        Args:
            inputs: Raw logits; ``sigmoid`` is applied here, so callers must
                not pass probabilities.
            targets: Ground-truth mask with the same number of elements as
                ``inputs``. It is cast to ``float`` before use.

        Returns:
            A zero-dimensional tensor holding ``1 - dice``.
        """
        # reshape (not view) so non-contiguous tensors -- e.g. the result of
        # permute/transpose or strided slicing -- flatten without raising.
        probs = torch.sigmoid(inputs).reshape(-1)
        flat_targets = targets.reshape(-1).float()

        intersection = (probs * flat_targets).sum()
        dice = (2. * intersection + self.smooth) / (
            probs.sum() + flat_targets.sum() + self.smooth
        )
        return 1 - dice