123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- """
PyTorch adaptation of the batch-all and batch-hard online triplet-mining losses
described at https://omoindrot.github.io/triplet-loss (TensorFlow original:
https://github.com/omoindrot/tensorflow-triplet-loss).
- """
- import torch
- import torch.nn as nn
class TripletMarginLoss(nn.Module):
    """Triplet margin loss with online (in-batch) triplet mining.

    Args:
        margin: Value added to ``d(anchor, positive) - d(anchor, negative)``
            before the result is clamped at zero.
        p: Norm degree forwarded to ``torch.cdist`` for pairwise distances.
        mining: Triplet selection strategy. ``"batch_all"`` averages over all
            valid triplets whose loss is positive; ``"batch_hard"`` uses the
            hardest positive and hardest negative for every anchor.

    Raises:
        ValueError: If ``mining`` is not one of the supported strategies.
    """

    def __init__(self, margin=1.0, p=2.0, mining="batch_all"):
        super().__init__()
        self.margin = margin
        self.p = p
        self.mining = mining
        # Fail fast: an unknown strategy used to leave ``loss_fn`` unset and
        # only surfaced later as an AttributeError inside ``forward``.
        if mining == "batch_all":
            self.loss_fn = batch_all_triplet_loss
        elif mining == "batch_hard":
            self.loss_fn = batch_hard_triplet_loss
        else:
            raise ValueError(
                f"Unknown mining strategy {mining!r}; "
                "expected 'batch_all' or 'batch_hard'"
            )

    def forward(self, embeddings, labels):
        """Compute the mined triplet loss for one batch.

        Args:
            embeddings: Embedding vectors passed to ``torch.cdist``
                (presumably shape ``(batch, dim)`` -- TODO confirm).
            labels: Per-sample class labels compared elementwise
                (presumably 1-D of length ``batch`` -- TODO confirm).

        Returns:
            The ``(loss, extra)`` tuple from the selected strategy: ``extra``
            is the positive-triplet fraction for ``"batch_all"`` and ``-1``
            for ``"batch_hard"``.
        """
        return self.loss_fn(labels, embeddings, self.margin, self.p)
def batch_hard_triplet_loss(labels, embeddings, margin, p):
    """Batch-hard triplet loss: hardest positive vs. hardest negative per anchor.

    For every anchor row the largest distance to a same-label sample (other
    than itself) and the smallest distance to a different-label sample are
    selected. The hinge ``max(pos - neg + margin, 0)`` is then averaged over
    all anchors. An anchor with no positive gets a hardest-positive distance
    of 0, because its masked row is all zeros.

    Args:
        labels: Per-sample labels (presumably 1-D of length ``batch``).
        embeddings: Embeddings passed to ``torch.cdist``.
        margin: Margin added inside the hinge.
        p: Norm degree for ``torch.cdist``.

    Returns:
        ``(loss, -1)``: the scalar mean loss tensor and a ``-1`` placeholder in
        the slot where ``batch_all_triplet_loss`` returns a fraction.
    """
    dist = torch.cdist(embeddings, embeddings, p=p)

    # Hardest positive: zero every non-positive pair, then take the row max.
    # Distances are non-negative, so the zeroed entries can never exceed a
    # real positive distance.
    pos_mask = _get_anchor_positive_triplet_mask(labels).float()
    hardest_pos = (pos_mask * dist).max(1, keepdim=True).values

    # Hardest negative: shift non-negative pairs up by the row maximum so they
    # are never smaller than any genuine negative, then take the row min.
    neg_mask = _get_anchor_negative_triplet_mask(labels).float()
    row_max = dist.max(1, keepdim=True).values
    shifted = dist + row_max * (1.0 - neg_mask)
    hardest_neg = shifted.min(1, keepdim=True).values

    # Hinge at zero (same values and gradients as masking negatives to 0).
    per_anchor = (hardest_pos - hardest_neg + margin).clamp_min(0)
    return per_anchor.mean(), -1
def batch_all_triplet_loss(labels, embeddings, margin, p):
    """Batch-all triplet loss averaged over the triplets that still incur loss.

    Builds ``loss[i, j, k] = d(i, j) - d(i, k) + margin`` for anchor ``i``,
    positive ``j`` and negative ``k``. Triplets rejected by the validity mask
    are zeroed, the result is clamped at zero, and the sum is divided by the
    number of triplets whose loss exceeds ``1e-16``.

    NOTE(review): for 2-D ``embeddings`` of shape ``(batch, dim)`` (assumed),
    the intermediate tensors hold ``batch ** 3`` elements.

    Args:
        labels: Per-sample labels (presumably 1-D of length ``batch``).
        embeddings: Embeddings passed to ``torch.cdist``.
        margin: Margin added to every triplet.
        p: Norm degree for ``torch.cdist``.

    Returns:
        ``(loss, fraction_positive_triplets)``. The fraction is the count of
        triplets with loss above ``1e-16`` divided by the count of valid
        triplets. Both denominators are offset by ``1e-16`` to avoid dividing
        by zero.
    """
    dist = torch.cdist(embeddings, embeddings, p=p)

    # Broadcast d(i, j) along k and d(i, k) along j to form the triplet cube.
    cube = dist.unsqueeze(2) - dist.unsqueeze(1) + margin

    valid = _get_triplet_mask(labels)
    # Zero invalid triplets, then drop easy (negative-loss) ones.
    cube = (valid.float() * cube).clamp_min(0)

    # Count triplets that still carry loss; kept as a Python int.
    num_positive = int((cube > 1e-16).sum())
    num_valid = valid.sum()

    fraction_positive = num_positive / (num_valid.float() + 1e-16)
    mean_loss = cube.sum() / (num_positive + 1e-16)
    return mean_loss, fraction_positive
- def _get_triplet_mask(labels):
- # Check that i, j and k are distinct
- indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
- indices_not_equal = ~indices_equal
- i_not_equal_j = indices_not_equal.unsqueeze(2)
- i_not_equal_k = indices_not_equal.unsqueeze(1)
- j_not_equal_k = indices_not_equal.unsqueeze(0)
- distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
- label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
- i_equal_j = label_equal.unsqueeze(2)
- i_equal_k = label_equal.unsqueeze(1)
- valid_labels = ~i_equal_k & i_equal_j
- return valid_labels & distinct_indices
- def _get_anchor_positive_triplet_mask(labels):
- # Check that i and j are distinct
- indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
- indices_not_equal = ~indices_equal
- # Check if labels[i] == labels[j]
- labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
- return labels_equal & indices_not_equal
- def _get_anchor_negative_triplet_mask(labels):
- return labels.unsqueeze(0) != labels.unsqueeze(1)
|