import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    """Contrastive loss for Siamese networks.

    For each pair the loss is::

        (1 - label) * d**2 + label * max(margin - d, 0)**2

    where ``d = F.pairwise_distance(x1, x2)``. A label of 0 therefore
    penalises any distance between the two embeddings, while a label of 1
    penalises only pairs that are closer than ``margin``. The per-pair
    losses are averaged into a single scalar.

    Args:
        margin: Distance beyond which label-1 pairs contribute no loss.
    """

    def __init__(self, margin: float = 1.0) -> None:
        super().__init__()
        self.margin = margin

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, label) -> torch.Tensor:
        """Return the mean contrastive loss over all pairs.

        Args:
            x1: First embeddings, as accepted by ``F.pairwise_distance``.
            x2: Second embeddings, same shape as ``x1``.
            label: Per-pair labels (0 or 1 per the formula above). May be a
                float, integer or bool tensor, or a plain Python number.
                A tensor with one entry per pair but a different shape
                (e.g. ``(N, 1)``) is reshaped to line up with the distances.

        Returns:
            A 0-dim tensor holding the mean loss.
        """
        dist = F.pairwise_distance(x1, x2)

        if torch.is_tensor(label):
            # Bool tensors do not support ``1 - label``; integer labels give
            # the same result after the cast because they would be promoted
            # to the distance dtype by the multiplication anyway.
            if not (label.is_floating_point() or label.is_complex()):
                label = label.to(dist.dtype)
            # ``dist`` has one entry per pair. A label shaped (N, 1) would
            # broadcast against it into an (N, N) matrix and average over
            # mismatched pairs, so align it whenever the counts agree.
            if label.shape != dist.shape and label.numel() == dist.numel():
                label = label.reshape(dist.shape)

        pull_term = (1 - label) * dist.pow(2)
        push_term = label * torch.clamp(self.margin - dist, min=0.0).pow(2)
        return (pull_term + push_term).mean()