""" Pytorch adaptation of https://omoindrot.github.io/triplet-loss https://github.com/omoindrot/tensorflow-triplet-loss """ import torch import torch.nn as nn class TripletMarginLoss(nn.Module): def __init__(self, margin=1.0, p=2.0, mining="batch_all"): super().__init__() self.margin = margin self.p = p self.mining = mining if mining == "batch_all": self.loss_fn = batch_all_triplet_loss if mining == "batch_hard": self.loss_fn = batch_hard_triplet_loss def forward(self, embeddings, labels): return self.loss_fn(labels, embeddings, self.margin, self.p) def batch_hard_triplet_loss(labels, embeddings, margin, p): pairwise_dist = torch.cdist(embeddings, embeddings, p=p) mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() anchor_positive_dist = mask_anchor_positive * pairwise_dist # hardest positive for every anchor hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() # Add max value in each row to invalid negatives max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) # hardest negative for every anchor hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) triplet_loss = hardest_positive_dist - hardest_negative_dist + margin triplet_loss[triplet_loss < 0] = 0 triplet_loss = triplet_loss.mean() return triplet_loss, -1 def batch_all_triplet_loss(labels, embeddings, margin, p): pairwise_dist = torch.cdist(embeddings, embeddings, p=p) anchor_positive_dist = pairwise_dist.unsqueeze(2) anchor_negative_dist = pairwise_dist.unsqueeze(1) triplet_loss = anchor_positive_dist - anchor_negative_dist + margin mask = _get_triplet_mask(labels) triplet_loss = mask.float() * triplet_loss # Remove negative losses (easy triplets) triplet_loss[triplet_loss < 0] = 0 # Count number of positive triplets (where triplet_loss > 0) valid_triplets = triplet_loss[triplet_loss > 1e-16] num_positive_triplets = valid_triplets.size(0) num_valid_triplets = mask.sum() fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) # Get final mean triplet loss over the positive valid triplets triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) return triplet_loss, fraction_positive_triplets def _get_triplet_mask(labels): # Check that i, j and k are distinct indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device) indices_not_equal = ~indices_equal i_not_equal_j = indices_not_equal.unsqueeze(2) i_not_equal_k = indices_not_equal.unsqueeze(1) j_not_equal_k = indices_not_equal.unsqueeze(0) distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) i_equal_j = label_equal.unsqueeze(2) i_equal_k = label_equal.unsqueeze(1) valid_labels = ~i_equal_k & i_equal_j return valid_labels & distinct_indices def _get_anchor_positive_triplet_mask(labels): # Check that i and j are distinct indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device) indices_not_equal = ~indices_equal # Check if labels[i] == labels[j] labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) return labels_equal & indices_not_equal def _get_anchor_negative_triplet_mask(labels): return labels.unsqueeze(0) != labels.unsqueeze(1)