| import torch.nn.functional as F | |
class CosineEmbeddingLoss:
    """Cosine Embedding Loss for similarity learning.

    Thin callable wrapper around ``torch.nn.functional.cosine_embedding_loss``.
    For each pair ``i`` the functional computes::

        loss_i = 1 - cos(x1_i, x2_i)                 if label_i == 1
        loss_i = max(0, cos(x1_i, x2_i) - margin)    if label_i == -1

    The per-pair losses are then combined according to ``reduction``.

    Args:
        margin: Similarity threshold for dissimilar pairs (label -1). Such
            pairs contribute zero loss once their cosine similarity is at or
            below ``margin``. PyTorch documents values in [-1, 1] and
            suggests 0 to 0.5.
        reduction: How per-pair losses are combined. ``"mean"`` is the
            default and matches the previous hard-coded behaviour. ``"sum"``
            adds them up, and ``"none"`` returns the unreduced per-pair
            tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, margin=0.0, *, reduction="mean"):
        # Validate eagerly, so a typo surfaces at construction time rather
        # than on the first training step.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.margin = margin
        self.reduction = reduction

    def __call__(self, x1, x2, label):
        """Return the cosine embedding loss between ``x1`` and ``x2``.

        Args:
            x1, x2: Input tensors compared along their feature dimension
                (shapes as accepted by ``F.cosine_embedding_loss``).
            label: Tensor of targets, 1 for similar pairs and -1 for
                dissimilar pairs.

        Returns:
            A scalar tensor for ``"mean"``/``"sum"``, or a per-pair tensor
            for ``"none"``.
        """
        return F.cosine_embedding_loss(
            x1, x2, label, margin=self.margin, reduction=self.reduction
        )