loss.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from ultralytics.utils.loss import FocalLoss, VarifocalLoss
  6. from ultralytics.utils.metrics import bbox_iou
  7. from .ops import HungarianMatcher
  8. class DETRLoss(nn.Module):
  9. def __init__(self,
  10. nc=80,
  11. loss_gain=None,
  12. aux_loss=True,
  13. use_fl=True,
  14. use_vfl=False,
  15. use_uni_match=False,
  16. uni_match_ind=0):
  17. """
  18. DETR loss function.
  19. Args:
  20. nc (int): The number of classes.
  21. loss_gain (dict): The coefficient of loss.
  22. aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
  23. use_vfl (bool): Use VarifocalLoss or not.
  24. use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
  25. uni_match_ind (int): The fixed indices of a layer.
  26. """
  27. super().__init__()
  28. if loss_gain is None:
  29. loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
  30. self.nc = nc
  31. self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
  32. self.loss_gain = loss_gain
  33. self.aux_loss = aux_loss
  34. self.fl = FocalLoss() if use_fl else None
  35. self.vfl = VarifocalLoss() if use_vfl else None
  36. self.use_uni_match = use_uni_match
  37. self.uni_match_ind = uni_match_ind
  38. self.device = None
  39. def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
  40. # logits: [b, query, num_classes], gt_class: list[[n, 1]]
  41. name_class = f'loss_class{postfix}'
  42. bs, nq = pred_scores.shape[:2]
  43. # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
  44. one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
  45. one_hot.scatter_(2, targets.unsqueeze(-1), 1)
  46. one_hot = one_hot[..., :-1]
  47. gt_scores = gt_scores.view(bs, nq, 1) * one_hot
  48. if self.fl:
  49. if num_gts and self.vfl:
  50. loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
  51. else:
  52. loss_cls = self.fl(pred_scores, one_hot.float())
  53. loss_cls /= max(num_gts, 1) / nq
  54. else:
  55. loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
  56. return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
  57. def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
  58. # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
  59. name_bbox = f'loss_bbox{postfix}'
  60. name_giou = f'loss_giou{postfix}'
  61. loss = {}
  62. if len(gt_bboxes) == 0:
  63. loss[name_bbox] = torch.tensor(0., device=self.device)
  64. loss[name_giou] = torch.tensor(0., device=self.device)
  65. return loss
  66. loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
  67. loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
  68. loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
  69. loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
  70. loss = {k: v.squeeze() for k, v in loss.items()}
  71. return loss
  72. def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
  73. # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
  74. name_mask = f'loss_mask{postfix}'
  75. name_dice = f'loss_dice{postfix}'
  76. loss = {}
  77. if sum(len(a) for a in gt_mask) == 0:
  78. loss[name_mask] = torch.tensor(0., device=self.device)
  79. loss[name_dice] = torch.tensor(0., device=self.device)
  80. return loss
  81. num_gts = len(gt_mask)
  82. src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
  83. src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
  84. # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
  85. loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
  86. torch.tensor([num_gts], dtype=torch.float32))
  87. loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
  88. return loss
  89. def _dice_loss(self, inputs, targets, num_gts):
  90. inputs = F.sigmoid(inputs)
  91. inputs = inputs.flatten(1)
  92. targets = targets.flatten(1)
  93. numerator = 2 * (inputs * targets).sum(1)
  94. denominator = inputs.sum(-1) + targets.sum(-1)
  95. loss = 1 - (numerator + 1) / (denominator + 1)
  96. return loss.sum() / num_gts
  97. def _get_loss_aux(self,
  98. pred_bboxes,
  99. pred_scores,
  100. gt_bboxes,
  101. gt_cls,
  102. gt_groups,
  103. match_indices=None,
  104. postfix='',
  105. masks=None,
  106. gt_mask=None):
  107. """Get auxiliary losses"""
  108. # NOTE: loss class, bbox, giou, mask, dice
  109. loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
  110. if match_indices is None and self.use_uni_match:
  111. match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
  112. pred_scores[self.uni_match_ind],
  113. gt_bboxes,
  114. gt_cls,
  115. gt_groups,
  116. masks=masks[self.uni_match_ind] if masks is not None else None,
  117. gt_mask=gt_mask)
  118. for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
  119. aux_masks = masks[i] if masks is not None else None
  120. loss_ = self._get_loss(aux_bboxes,
  121. aux_scores,
  122. gt_bboxes,
  123. gt_cls,
  124. gt_groups,
  125. masks=aux_masks,
  126. gt_mask=gt_mask,
  127. postfix=postfix,
  128. match_indices=match_indices)
  129. loss[0] += loss_[f'loss_class{postfix}']
  130. loss[1] += loss_[f'loss_bbox{postfix}']
  131. loss[2] += loss_[f'loss_giou{postfix}']
  132. # if masks is not None and gt_mask is not None:
  133. # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
  134. # loss[3] += loss_[f'loss_mask{postfix}']
  135. # loss[4] += loss_[f'loss_dice{postfix}']
  136. loss = {
  137. f'loss_class_aux{postfix}': loss[0],
  138. f'loss_bbox_aux{postfix}': loss[1],
  139. f'loss_giou_aux{postfix}': loss[2]}
  140. # if masks is not None and gt_mask is not None:
  141. # loss[f'loss_mask_aux{postfix}'] = loss[3]
  142. # loss[f'loss_dice_aux{postfix}'] = loss[4]
  143. return loss
  144. def _get_index(self, match_indices):
  145. batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
  146. src_idx = torch.cat([src for (src, _) in match_indices])
  147. dst_idx = torch.cat([dst for (_, dst) in match_indices])
  148. return (batch_idx, src_idx), dst_idx
  149. def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
  150. pred_assigned = torch.cat([
  151. t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
  152. for t, (I, _) in zip(pred_bboxes, match_indices)])
  153. gt_assigned = torch.cat([
  154. t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
  155. for t, (_, J) in zip(gt_bboxes, match_indices)])
  156. return pred_assigned, gt_assigned
  157. def _get_loss(self,
  158. pred_bboxes,
  159. pred_scores,
  160. gt_bboxes,
  161. gt_cls,
  162. gt_groups,
  163. masks=None,
  164. gt_mask=None,
  165. postfix='',
  166. match_indices=None):
  167. """Get losses"""
  168. if match_indices is None:
  169. match_indices = self.matcher(pred_bboxes,
  170. pred_scores,
  171. gt_bboxes,
  172. gt_cls,
  173. gt_groups,
  174. masks=masks,
  175. gt_mask=gt_mask)
  176. idx, gt_idx = self._get_index(match_indices)
  177. pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
  178. bs, nq = pred_scores.shape[:2]
  179. targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
  180. targets[idx] = gt_cls[gt_idx]
  181. gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
  182. if len(gt_bboxes):
  183. gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
  184. loss = {}
  185. loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
  186. loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
  187. # if masks is not None and gt_mask is not None:
  188. # loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
  189. return loss
  190. def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
  191. """
  192. Args:
  193. pred_bboxes (torch.Tensor): [l, b, query, 4]
  194. pred_scores (torch.Tensor): [l, b, query, num_classes]
  195. batch (dict): A dict includes:
  196. gt_cls (torch.Tensor) with shape [num_gts, ],
  197. gt_bboxes (torch.Tensor): [num_gts, 4],
  198. gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
  199. postfix (str): postfix of loss name.
  200. """
  201. self.device = pred_bboxes.device
  202. match_indices = kwargs.get('match_indices', None)
  203. gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
  204. total_loss = self._get_loss(pred_bboxes[-1],
  205. pred_scores[-1],
  206. gt_bboxes,
  207. gt_cls,
  208. gt_groups,
  209. postfix=postfix,
  210. match_indices=match_indices)
  211. if self.aux_loss:
  212. total_loss.update(
  213. self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
  214. postfix))
  215. return total_loss
  216. class RTDETRDetectionLoss(DETRLoss):
  217. def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
  218. pred_bboxes, pred_scores = preds
  219. total_loss = super().forward(pred_bboxes, pred_scores, batch)
  220. if dn_meta is not None:
  221. dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
  222. assert len(batch['gt_groups']) == len(dn_pos_idx)
  223. # denoising match indices
  224. match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
  225. # compute denoising training loss
  226. dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
  227. total_loss.update(dn_loss)
  228. else:
  229. total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
  230. return total_loss
  231. @staticmethod
  232. def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
  233. """Get the match indices for denoising.
  234. Args:
  235. dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
  236. dn_num_group (int): The number of groups of denoising.
  237. gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
  238. Returns:
  239. dn_match_indices (List(tuple)): Matched indices.
  240. """
  241. dn_match_indices = []
  242. idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
  243. for i, num_gt in enumerate(gt_groups):
  244. if num_gt > 0:
  245. gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
  246. gt_idx = gt_idx.repeat(dn_num_group)
  247. assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
  248. f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
  249. dn_match_indices.append((dn_pos_idx[i], gt_idx))
  250. else:
  251. dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
  252. return dn_match_indices