torchlosses / triplet.py
GenAIDevTOProd's picture
Upload folder using huggingface_hub
4c1e73e verified
raw
history blame contribute delete
303 Bytes
import torch.nn.functional as F
class TripletLoss:
    """Triplet Margin Loss.

    A callable wrapper around ``torch.nn.functional.triplet_margin_loss``.
    For each (anchor, positive, negative) triplet the per-sample loss is
    ``max(d(anchor, positive) - d(anchor, negative) + margin, 0)``, where
    ``d`` is the ``p``-norm pairwise distance computed by PyTorch.

    Args:
        margin: Desired gap between the positive and negative distances.
        p: Norm degree used for the pairwise distance (2.0 = Euclidean).
        eps: Small constant added by PyTorch for numerical stability.
        swap: If True, use the distance-swap variant, which uses
            ``d(positive, negative)`` when it is smaller than
            ``d(anchor, negative)``.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``. The default
            ``'mean'`` preserves the original hard-coded behaviour.

    The defaults for ``p``, ``eps`` and ``swap`` match PyTorch's own
    defaults, so ``TripletLoss(margin)`` behaves exactly as before.
    """

    def __init__(self, margin=1.0, *, p=2.0, eps=1e-6, swap=False,
                 reduction='mean'):
        self.margin = margin
        self.p = p
        self.eps = eps
        self.swap = swap
        self.reduction = reduction

    def __call__(self, anchor, positive, negative):
        """Compute the triplet margin loss for a batch of triplets.

        Args:
            anchor, positive, negative: Embedding tensors forwarded
                unchanged to PyTorch. They presumably have shape
                (N, D) or (D,); confirm this against the callers.

        Returns:
            A tensor. It is a scalar for ``'mean'``/``'sum'`` and holds
            one value per triplet for ``'none'``.
        """
        return F.triplet_margin_loss(
            anchor,
            positive,
            negative,
            margin=self.margin,
            p=self.p,
            eps=self.eps,
            swap=self.swap,
            reduction=self.reduction,
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}(margin={self.margin!r}, p={self.p!r}, "
            f"eps={self.eps!r}, swap={self.swap!r}, "
            f"reduction={self.reduction!r})"
        )