loss.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. """
  2. PyTorch adaptation of https://omoindrot.github.io/triplet-loss
  3. https://github.com/omoindrot/tensorflow-triplet-loss
  4. """
  5. import torch
  6. import torch.nn as nn
  7. class TripletMarginLoss(nn.Module):
  8. def __init__(self, margin=1.0, p=2.0, mining="batch_all"):
  9. super().__init__()
  10. self.margin = margin
  11. self.p = p
  12. self.mining = mining
  13. if mining == "batch_all":
  14. self.loss_fn = batch_all_triplet_loss
  15. if mining == "batch_hard":
  16. self.loss_fn = batch_hard_triplet_loss
  17. def forward(self, embeddings, labels):
  18. return self.loss_fn(labels, embeddings, self.margin, self.p)
  19. def batch_hard_triplet_loss(labels, embeddings, margin, p):
  20. pairwise_dist = torch.cdist(embeddings, embeddings, p=p)
  21. mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
  22. anchor_positive_dist = mask_anchor_positive * pairwise_dist
  23. # hardest positive for every anchor
  24. hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
  25. mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
  26. # Add max value in each row to invalid negatives
  27. max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
  28. anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
  29. # hardest negative for every anchor
  30. hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
  31. triplet_loss = hardest_positive_dist - hardest_negative_dist + margin
  32. triplet_loss[triplet_loss < 0] = 0
  33. triplet_loss = triplet_loss.mean()
  34. return triplet_loss, -1
  35. def batch_all_triplet_loss(labels, embeddings, margin, p):
  36. pairwise_dist = torch.cdist(embeddings, embeddings, p=p)
  37. anchor_positive_dist = pairwise_dist.unsqueeze(2)
  38. anchor_negative_dist = pairwise_dist.unsqueeze(1)
  39. triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
  40. mask = _get_triplet_mask(labels)
  41. triplet_loss = mask.float() * triplet_loss
  42. # Remove negative losses (easy triplets)
  43. triplet_loss[triplet_loss < 0] = 0
  44. # Count number of positive triplets (where triplet_loss > 0)
  45. valid_triplets = triplet_loss[triplet_loss > 1e-16]
  46. num_positive_triplets = valid_triplets.size(0)
  47. num_valid_triplets = mask.sum()
  48. fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
  49. # Get final mean triplet loss over the positive valid triplets
  50. triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
  51. return triplet_loss, fraction_positive_triplets
  52. def _get_triplet_mask(labels):
  53. # Check that i, j and k are distinct
  54. indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
  55. indices_not_equal = ~indices_equal
  56. i_not_equal_j = indices_not_equal.unsqueeze(2)
  57. i_not_equal_k = indices_not_equal.unsqueeze(1)
  58. j_not_equal_k = indices_not_equal.unsqueeze(0)
  59. distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
  60. label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
  61. i_equal_j = label_equal.unsqueeze(2)
  62. i_equal_k = label_equal.unsqueeze(1)
  63. valid_labels = ~i_equal_k & i_equal_j
  64. return valid_labels & distinct_indices
  65. def _get_anchor_positive_triplet_mask(labels):
  66. # Check that i and j are distinct
  67. indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
  68. indices_not_equal = ~indices_equal
  69. # Check if labels[i] == labels[j]
  70. labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
  71. return labels_equal & indices_not_equal
  72. def _get_anchor_negative_triplet_mask(labels):
  73. return labels.unsqueeze(0) != labels.unsqueeze(1)