| import torch | |
| import torch.nn as nn | |
class HuberLoss(nn.Module):
    """Huber loss between ``inputs`` and ``targets``.

    Element-wise, with ``d = |inputs - targets|``::

        loss = 0.5 * d**2                  if d < delta
        loss = delta * (d - 0.5 * delta)   otherwise

    The loss is quadratic near zero and linear for large residuals. The two
    pieces meet continuously at ``d == delta``.

    This coincides with Smooth L1 loss only when ``delta == 1``. In general
    ``Huber(delta) == delta * SmoothL1(beta=delta)``.

    Args:
        delta: Threshold where the loss switches from quadratic to linear.
            Must be strictly positive.
        reduction: ``'none'`` returns the element-wise loss tensor.
            ``'mean'`` returns its mean and ``'sum'`` returns its sum.

    Raises:
        ValueError: If ``delta <= 0`` or ``reduction`` is not one of the
            supported values.
    """

    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self, delta: float = 1.0, reduction: str = "mean") -> None:
        super().__init__()
        if not delta > 0:  # also rejects NaN
            raise ValueError(f"delta must be > 0, got {delta!r}")
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.delta = delta
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the Huber loss, reduced according to ``self.reduction``.

        ``inputs - targets`` follows normal tensor broadcasting rules.
        """
        diff = torch.abs(inputs - targets)
        loss = torch.where(
            diff < self.delta,
            0.5 * diff**2,
            self.delta * (diff - 0.5 * self.delta),
        )
        if self.reduction == "none":
            return loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # Re-checked here because `reduction` is a plain attribute and may be
        # reassigned after construction.
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
        )