ops.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from scipy.optimize import linear_sum_assignment
  6. from ultralytics.utils.metrics import bbox_iou
  7. from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
  8. class HungarianMatcher(nn.Module):
  9. """
  10. A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in
  11. an end-to-end fashion.
  12. HungarianMatcher performs optimal assignment over predicted and ground truth bounding boxes using a cost function
  13. that considers classification scores, bounding box coordinates, and optionally, mask predictions.
  14. Attributes:
  15. cost_gain (dict): Dictionary of cost coefficients for different components: 'class', 'bbox', 'giou', 'mask', and 'dice'.
  16. use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
  17. with_mask (bool): Indicates whether the model makes mask predictions.
  18. num_sample_points (int): The number of sample points used in mask cost calculation.
  19. alpha (float): The alpha factor in Focal Loss calculation.
  20. gamma (float): The gamma factor in Focal Loss calculation.
  21. Methods:
  22. forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the assignment
  23. between predictions and ground truths for a batch.
  24. _cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
  25. """
  26. def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
  27. super().__init__()
  28. if cost_gain is None:
  29. cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
  30. self.cost_gain = cost_gain
  31. self.use_fl = use_fl
  32. self.with_mask = with_mask
  33. self.num_sample_points = num_sample_points
  34. self.alpha = alpha
  35. self.gamma = gamma
  36. def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
  37. """
  38. Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
  39. (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching
  40. between predictions and ground truth based on these costs.
  41. Args:
  42. pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
  43. pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes].
  44. gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ].
  45. gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4].
  46. gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
  47. each image.
  48. masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width].
  49. Defaults to None.
  50. gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
  51. Defaults to None.
  52. Returns:
  53. (List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
  54. - index_i is the tensor of indices of the selected predictions (in order)
  55. - index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
  56. For each batch element, it holds:
  57. len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  58. """
  59. bs, nq, nc = pred_scores.shape
  60. if sum(gt_groups) == 0:
  61. return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
  62. # We flatten to compute the cost matrices in a batch
  63. # [batch_size * num_queries, num_classes]
  64. pred_scores = pred_scores.detach().view(-1, nc)
  65. pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
  66. # [batch_size * num_queries, 4]
  67. pred_bboxes = pred_bboxes.detach().view(-1, 4)
  68. # Compute the classification cost
  69. pred_scores = pred_scores[:, gt_cls]
  70. if self.use_fl:
  71. neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
  72. pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
  73. cost_class = pos_cost_class - neg_cost_class
  74. else:
  75. cost_class = -pred_scores
  76. # Compute the L1 cost between boxes
  77. cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
  78. # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
  79. cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
  80. # Final cost matrix
  81. C = self.cost_gain['class'] * cost_class + \
  82. self.cost_gain['bbox'] * cost_bbox + \
  83. self.cost_gain['giou'] * cost_giou
  84. # Compute the mask cost and dice cost
  85. if self.with_mask:
  86. C += self._cost_mask(bs, gt_groups, masks, gt_mask)
  87. C = C.view(bs, nq, -1).cpu()
  88. indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
  89. gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
  90. # (idx for queries, idx for gt)
  91. return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
  92. for k, (i, j) in enumerate(indices)]
  93. def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
  94. assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
  95. # all masks share the same set of points for efficient matching
  96. sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
  97. sample_points = 2.0 * sample_points - 1.0
  98. out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
  99. out_mask = out_mask.flatten(0, 1)
  100. tgt_mask = torch.cat(gt_mask).unsqueeze(1)
  101. sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
  102. tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
  103. with torch.cuda.amp.autocast(False):
  104. # binary cross entropy cost
  105. pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
  106. neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
  107. cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
  108. cost_mask /= self.num_sample_points
  109. # dice cost
  110. out_mask = F.sigmoid(out_mask)
  111. numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
  112. denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
  113. cost_dice = 1 - (numerator + 1) / (denominator + 1)
  114. C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
  115. return C
  116. def get_cdn_group(batch,
  117. num_classes,
  118. num_queries,
  119. class_embed,
  120. num_dn=100,
  121. cls_noise_ratio=0.5,
  122. box_noise_scale=1.0,
  123. training=False):
  124. """
  125. Get contrastive denoising training group. This function creates a contrastive denoising training group with
  126. positive and negative samples from the ground truths (gt). It applies noise to the class labels and bounding
  127. box coordinates, and returns the modified labels, bounding boxes, attention mask and meta information.
  128. Args:
  129. batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
  130. (torch.Tensor with shape [num_gts, 4]), 'gt_groups' (List(int)) which is a list of batch size length
  131. indicating the number of gts of each image.
  132. num_classes (int): Number of classes.
  133. num_queries (int): Number of queries.
  134. class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
  135. num_dn (int, optional): Number of denoising. Defaults to 100.
  136. cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
  137. box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
  138. training (bool, optional): If it's in training mode. Defaults to False.
  139. Returns:
  140. (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
  141. bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
  142. is less than or equal to 0, the function returns None for all elements in the tuple.
  143. """
  144. if (not training) or num_dn <= 0:
  145. return None, None, None, None
  146. gt_groups = batch['gt_groups']
  147. total_num = sum(gt_groups)
  148. max_nums = max(gt_groups)
  149. if max_nums == 0:
  150. return None, None, None, None
  151. num_group = num_dn // max_nums
  152. num_group = 1 if num_group == 0 else num_group
  153. # pad gt to max_num of a batch
  154. bs = len(gt_groups)
  155. gt_cls = batch['cls'] # (bs*num, )
  156. gt_bbox = batch['bboxes'] # bs*num, 4
  157. b_idx = batch['batch_idx']
  158. # each group has positive and negative queries.
  159. dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
  160. dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
  161. dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
  162. # positive and negative mask
  163. # (bs*num*num_group, ), the second total_num*num_group part as negative samples
  164. neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
  165. if cls_noise_ratio > 0:
  166. # half of bbox prob
  167. mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
  168. idx = torch.nonzero(mask).squeeze(-1)
  169. # randomly put a new one here
  170. new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
  171. dn_cls[idx] = new_label
  172. if box_noise_scale > 0:
  173. known_bbox = xywh2xyxy(dn_bbox)
  174. diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
  175. rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
  176. rand_part = torch.rand_like(dn_bbox)
  177. rand_part[neg_idx] += 1.0
  178. rand_part *= rand_sign
  179. known_bbox += rand_part * diff
  180. known_bbox.clip_(min=0.0, max=1.0)
  181. dn_bbox = xyxy2xywh(known_bbox)
  182. dn_bbox = inverse_sigmoid(dn_bbox)
  183. # total denoising queries
  184. num_dn = int(max_nums * 2 * num_group)
  185. # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
  186. dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
  187. padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
  188. padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
  189. map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
  190. pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
  191. map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
  192. padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
  193. padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
  194. tgt_size = num_dn + num_queries
  195. attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
  196. # match query cannot see the reconstruct
  197. attn_mask[num_dn:, :num_dn] = True
  198. # reconstruct cannot see each other
  199. for i in range(num_group):
  200. if i == 0:
  201. attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
  202. if i == num_group - 1:
  203. attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
  204. else:
  205. attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
  206. attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
  207. dn_meta = {
  208. 'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
  209. 'dn_num_group': num_group,
  210. 'dn_num_split': [num_dn, num_queries]}
  211. return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
  212. class_embed.device), dn_meta
  213. def inverse_sigmoid(x, eps=1e-6):
  214. """Inverse sigmoid function."""
  215. x = x.clip(min=0., max=1.)
  216. return torch.log(x / (1 - x + eps) + eps)