_utils.py

import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss


class BalancedPositiveNegativeSampler:
    """
    This class samples batches, ensuring that they contain a fixed proportion of positives
    """

    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
        """
        Args:
            batch_size_per_image (int): number of elements to be selected per image
            positive_fraction (float): percentage of positive elements per batch
        """
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Args:
            matched_idxs: list of tensors containing -1, 0 or positive values.
                Each tensor corresponds to a specific image.
                -1 values are ignored, 0 are considered as negatives and > 0 as
                positives.

        Returns:
            pos_idx (list[tensor])
            neg_idx (list[tensor])

        Returns two lists of binary masks for each image.
        The first list contains the positive elements that were selected,
        and the second list the negative examples.
        """
        pos_idx = []
        neg_idx = []
        for matched_idxs_per_image in matched_idxs:
            positive = torch.where(matched_idxs_per_image >= 1)[0]
            negative = torch.where(matched_idxs_per_image == 0)[0]

            num_pos = int(self.batch_size_per_image * self.positive_fraction)
            # protect against not enough positive examples
            num_pos = min(positive.numel(), num_pos)
            num_neg = self.batch_size_per_image - num_pos
            # protect against not enough negative examples
            num_neg = min(negative.numel(), num_neg)

            # randomly select positive and negative examples
            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

            pos_idx_per_image = positive[perm1]
            neg_idx_per_image = negative[perm2]

            # create binary mask from indices
            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)

            pos_idx_per_image_mask[pos_idx_per_image] = 1
            neg_idx_per_image_mask[neg_idx_per_image] = 1

            pos_idx.append(pos_idx_per_image_mask)
            neg_idx.append(neg_idx_per_image_mask)

        return pos_idx, neg_idx
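
# Illustrative usage sketch (not part of the original module): assuming `matched_idxs` is the
# per-image output of a Matcher (-1 ignored, 0 negative, >= 1 positive), the sampler below picks
# at most 256 anchors per image, roughly half of them positives. The label tensor is hypothetical.
#
#     sampler = BalancedPositiveNegativeSampler(batch_size_per_image=256, positive_fraction=0.5)
#     matched_idxs = [torch.randint(-1, 2, (1000,))]            # fake labels for a single image
#     pos_masks, neg_masks = sampler(matched_idxs)              # lists of uint8 masks, one per image
#     sampled = torch.where(pos_masks[0] | neg_masks[0])[0]     # indices actually used in the loss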


@torch.jit._script_if_tracing
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
    """
    Encode a set of proposals with respect to some
    reference boxes

    Args:
        reference_boxes (Tensor): reference boxes
        proposals (Tensor): boxes to be encoded
        weights (Tensor[4]): the weights for ``(x, y, w, h)``
    """

    # perform some unpacking to make it JIT-fusion friendly
    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    proposals_x1 = proposals[:, 0].unsqueeze(1)
    proposals_y1 = proposals[:, 1].unsqueeze(1)
    proposals_x2 = proposals[:, 2].unsqueeze(1)
    proposals_y2 = proposals[:, 3].unsqueeze(1)

    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)

    # implementation starts here
    ex_widths = proposals_x2 - proposals_x1
    ex_heights = proposals_y2 - proposals_y1
    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
    ex_ctr_y = proposals_y1 + 0.5 * ex_heights

    gt_widths = reference_boxes_x2 - reference_boxes_x1
    gt_heights = reference_boxes_y2 - reference_boxes_y1
    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights

    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * torch.log(gt_widths / ex_widths)
    targets_dh = wh * torch.log(gt_heights / ex_heights)

    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
    return targets
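
# For reference (a sketch of what the code above computes): with proposal center/size
# (px, py, pw, ph) and reference (ground-truth) center/size (gx, gy, gw, gh), the targets are
#     dx = wx * (gx - px) / pw,   dy = wy * (gy - py) / ph,
#     dw = ww * log(gw / pw),     dh = wh * log(gh / ph)
# i.e. the standard R-CNN box regression parameterization.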


class BoxCoder:
    """
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
    ) -> None:
        """
        Args:
            weights (4-element tuple)
            bbox_xform_clip (float)
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
        boxes_per_image = [len(b) for b in reference_boxes]
        reference_boxes = torch.cat(reference_boxes, dim=0)
        proposals = torch.cat(proposals, dim=0)
        targets = self.encode_single(reference_boxes, proposals)
        return targets.split(boxes_per_image, 0)

    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some
        reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded
        """
        dtype = reference_boxes.dtype
        device = reference_boxes.device
        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
        targets = encode_boxes(reference_boxes, proposals, weights)

        return targets

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        torch._assert(
            isinstance(boxes, (list, tuple)),
            "This function expects boxes of type list or tuple.",
        )
        torch._assert(
            isinstance(rel_codes, torch.Tensor),
            "This function expects rel_codes of type torch.Tensor.",
        )
        boxes_per_image = [b.size(0) for b in boxes]
        concat_boxes = torch.cat(boxes, dim=0)
        box_sum = 0
        for val in boxes_per_image:
            box_sum += val
        if box_sum > 0:
            rel_codes = rel_codes.reshape(box_sum, -1)
        pred_boxes = self.decode_single(rel_codes, concat_boxes)
        if box_sum > 0:
            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
        return pred_boxes

    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.
        """
        boxes = boxes.to(rel_codes.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        dw = rel_codes[:, 2::4] / ww
        dh = rel_codes[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        # Distance from center to box's corner.
        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w

        pred_boxes1 = pred_ctr_x - c_to_c_w
        pred_boxes2 = pred_ctr_y - c_to_c_h
        pred_boxes3 = pred_ctr_x + c_to_c_w
        pred_boxes4 = pred_ctr_y + c_to_c_h
        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
        return pred_boxes
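
# Illustrative round-trip sketch (not part of the original module): encoding hypothetical
# ground-truth boxes against proposals and decoding the targets back should recover the
# ground truth (up to floating-point error).
#
#     box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
#     proposals = [torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]])]
#     gt_boxes = [torch.tensor([[1.0, 1.0, 11.0, 11.0], [4.0, 6.0, 21.0, 19.0]])]
#     rel_codes = box_coder.encode(gt_boxes, proposals)             # one (2, 4) tensor per image
#     decoded = box_coder.decode(torch.cat(rel_codes), proposals)   # (2, 1, 4), ~= gt_boxes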


class BoxLinearCoder:
    """
    The linear box-to-box transform defined in FCOS. The transformation is parameterized
    by the distance from the center of (square) src box to 4 edges of the target box.
    """

    def __init__(self, normalize_by_size: bool = True) -> None:
        """
        Args:
            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
        """
        self.normalize_by_size = normalize_by_size

    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded

        Returns:
            Tensor: the encoded relative box offsets that can be used to
            decode the boxes.
        """

        # get the center of reference_boxes
        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])

        # get box regression transformation deltas
        target_l = reference_boxes_ctr_x - proposals[..., 0]
        target_t = reference_boxes_ctr_y - proposals[..., 1]
        target_r = proposals[..., 2] - reference_boxes_ctr_x
        target_b = proposals[..., 3] - reference_boxes_ctr_y

        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)

        if self.normalize_by_size:
            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
            reference_boxes_size = torch.stack(
                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
            )
            targets = targets / reference_boxes_size
        return targets

    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.

        Returns:
            Tensor: the predicted boxes with the encoded relative box offsets.

        .. note::
            This method assumes that ``rel_codes`` and ``boxes`` have the same size for
            the 0th dimension, i.e. ``len(rel_codes) == len(boxes)``.
        """

        boxes = boxes.to(dtype=rel_codes.dtype)

        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])

        if self.normalize_by_size:
            boxes_w = boxes[..., 2] - boxes[..., 0]
            boxes_h = boxes[..., 3] - boxes[..., 1]

            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
            rel_codes = rel_codes * list_box_size

        pred_boxes1 = ctr_x - rel_codes[..., 0]
        pred_boxes2 = ctr_y - rel_codes[..., 1]
        pred_boxes3 = ctr_x + rel_codes[..., 2]
        pred_boxes4 = ctr_y + rel_codes[..., 3]

        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
        return pred_boxes
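
# Illustrative sketch (not part of the original module): FCOS-style encoding of a hypothetical
# ground-truth box against an anchor represented here as a degenerate point-like box, so
# normalization by anchor size is disabled to avoid dividing by zero.
#
#     linear_coder = BoxLinearCoder(normalize_by_size=False)
#     anchors = torch.tensor([[50.0, 50.0, 50.0, 50.0]])    # a single anchor center at (50, 50)
#     gt = torch.tensor([[30.0, 40.0, 80.0, 90.0]])
#     ltrb = linear_coder.encode(anchors, gt)               # distances from the center to gt edges
#     recovered = linear_coder.decode(ltrb, anchors)        # == gt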


class Matcher:
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be assigned to zero or more predicted elements.

    Matching is based on the MxN match_quality_matrix, which characterizes how well
    each (ground-truth, predicted) pair matches. For example, if the elements are
    boxes, the matrix may contain box IoU overlap values.

    The matcher returns a tensor of size N containing the index of the ground-truth
    element m that matches to prediction n. If there is no match, a negative value
    is returned.
    """

    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    __annotations__ = {
        "BELOW_LOW_THRESHOLD": int,
        "BETWEEN_THRESHOLDS": int,
    }

    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.
        """
        self.BELOW_LOW_THRESHOLD = -1
        self.BETWEEN_THRESHOLDS = -2
        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted elements.

        Returns:
            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
                [0, M - 1] or a negative value indicating that prediction i could not
                be matched.
        """
        if match_quality_matrix.numel() == 0:
            # empty targets or proposals not supported during training
            if match_quality_matrix.shape[0] == 0:
                raise ValueError("No ground-truth boxes available for one of the images during training")
            else:
                raise ValueError("No proposal boxes available for one of the images during training")

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            all_matches = matches.clone()
        else:
            all_matches = None  # type: ignore[assignment]

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            if all_matches is None:
                torch._assert(False, "all_matches should not be None")
            else:
                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.
        """
        # For each gt, find the prediction with which it has the highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find the highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
        # Example gt_pred_pairs_of_highest_quality:
        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
        # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
        # Note how gt items 1, 2, 3, and 5 each have two ties
        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
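
# Illustrative sketch (not part of the original module): matching 4 predictions against
# 2 ground-truth boxes with a hypothetical IoU-like quality matrix.
#
#     matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=True)
#     iou = torch.tensor([[0.9, 0.2, 0.4, 0.1],     # gt 0 vs predictions 0..3
#                         [0.1, 0.6, 0.05, 0.25]])  # gt 1 vs predictions 0..3
#     matches = matcher(iou)
#     # matches == tensor([0, 1, -2, -1]): pred 0 -> gt 0; pred 1 -> gt 1 (0.6 is between the
#     # thresholds but it is gt 1's best match, so allow_low_quality_matches keeps it);
#     # pred 2 is BETWEEN_THRESHOLDS; pred 3 is BELOW_LOW_THRESHOLD.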


class SSDMatcher(Matcher):
    def __init__(self, threshold: float) -> None:
        super().__init__(threshold, threshold, allow_low_quality_matches=False)

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        matches = super().__call__(match_quality_matrix)

        # For each gt, find the prediction with which it has the highest quality
        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
        matches[highest_quality_pred_foreach_gt] = torch.arange(
            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
        )

        return matches
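
# Illustrative note (not part of the original module): SSDMatcher behaves like a single-threshold
# Matcher, but additionally forces every ground-truth box to claim its best-overlapping prediction.
# Continuing the hypothetical `iou` matrix from the Matcher sketch above:
#
#     SSDMatcher(threshold=0.5)(iou)   # == tensor([0, 1, -1, -1]); pred 1 is kept only because
#                                      # it is gt 1's best prediction, despite its 0.6 < threshold
#                                      # being irrelevant here (0.6 >= 0.5) -- pred 2 and 3 fall below.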


def overwrite_eps(model: nn.Module, eps: float) -> None:
    """
    This method overwrites the default eps values of all the
    FrozenBatchNorm2d layers of the model with the provided value.
    This is necessary to address the BC-breaking change introduced
    by the bug-fix at pytorch/vision#2933. The overwrite is applied
    only when the pretrained weights are loaded to maintain compatibility
    with previous versions.

    Args:
        model (nn.Module): The model on which we perform the overwrite.
        eps (float): The new value of eps.
    """
    for module in model.modules():
        if isinstance(module, FrozenBatchNorm2d):
            module.eps = eps


def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
    """
    This method retrieves the number of output channels of a specific model.

    Args:
        model (nn.Module): The model for which we estimate the out_channels.
            It should return a single Tensor or an OrderedDict[Tensor].
        size (Tuple[int, int]): The size (wxh) of the input.

    Returns:
        out_channels (List[int]): A list of the output channels of the model.
    """
    in_training = model.training
    model.eval()

    with torch.no_grad():
        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
        device = next(model.parameters()).device
        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
        features = model(tmp_img)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])
        out_channels = [x.size(1) for x in features.values()]

    if in_training:
        model.train()

    return out_channels
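
# Illustrative sketch (not part of the original module): probing a toy single-output backbone
# with a dummy 320x320 input.
#
#     backbone = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU())
#     retrieve_out_channels(backbone, (320, 320))   # -> [64], one entry per returned feature map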


@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> int:
    return v  # type: ignore[return-value]


def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
    """
    ONNX spec requires the k-value to be less than or equal to the number of inputs along
    provided dim. Certain models use the number of elements along a particular axis instead of K
    if K exceeds the number of elements along that axis. Previously, python's min() function was
    used to determine whether to use the provided k-value or the specified dim axis value.
    However, in cases where the model is being exported in tracing mode, python min() is
    static, causing the model to be traced incorrectly and eventually fail at the topk node.
    In order to avoid this situation, in tracing mode, torch.min() is used instead.

    Args:
        input (Tensor): The original input tensor.
        orig_kval (int): The provided k-value.
        axis (int): Axis along which we retrieve the input size.

    Returns:
        min_kval (int): Appropriately selected k-value.
    """
    if not torch.jit.is_tracing():
        return min(orig_kval, input.size(axis))
    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
    return _fake_cast_onnx(min_kval)
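
# Illustrative note (not part of the original module): in eager mode this is simply
# `min(k, input.size(axis))`, e.g. _topk_min(torch.zeros(10, 5), orig_kval=1000, axis=1) == 5,
# while under tracing the same clamp is expressed with tensor ops so that it is recorded in the
# exported graph instead of being frozen as a Python constant.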


def _box_loss(
    type: str,
    box_coder: BoxCoder,
    anchors_per_image: Tensor,
    matched_gt_boxes_per_image: Tensor,
    bbox_regression_per_image: Tensor,
    cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")

    if type == "l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
    elif type == "smooth_l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
    else:
        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
        if type == "ciou":
            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
        if type == "diou":
            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
        # otherwise giou
        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
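
# Illustrative sketch (not part of the original module): computing a GIoU regression loss for two
# anchors whose hypothetical regression outputs exactly encode the matched ground truth, so the
# resulting loss is approximately zero.
#
#     coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
#     anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0], [10.0, 10.0, 30.0, 30.0]])
#     matched_gt = torch.tensor([[1.0, 1.0, 11.0, 11.0], [12.0, 10.0, 28.0, 32.0]])
#     bbox_regression = coder.encode_single(matched_gt, anchors)    # pretend-perfect predictions
#     loss = _box_loss("giou", coder, anchors, matched_gt, bbox_regression)   # ~0.0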