import math from collections import OrderedDict from typing import Dict, List, Optional, Tuple import torch from torch import nn, Tensor from torch.nn import functional as F from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss class BalancedPositiveNegativeSampler: """ This class samples batches, ensuring that they contain a fixed proportion of positives """ def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None: """ Args: batch_size_per_image (int): number of elements to be selected per image positive_fraction (float): percentage of positive elements per batch """ self.batch_size_per_image = batch_size_per_image self.positive_fraction = positive_fraction def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: """ Args: matched_idxs: list of tensors containing -1, 0 or positive values. Each tensor corresponds to a specific image. -1 values are ignored, 0 are considered as negatives and > 0 as positives. Returns: pos_idx (list[tensor]) neg_idx (list[tensor]) Returns two lists of binary masks for each image. The first list contains the positive elements that were selected, and the second list the negative example. """ pos_idx = [] neg_idx = [] for matched_idxs_per_image in matched_idxs: positive = torch.where(matched_idxs_per_image >= 1)[0] negative = torch.where(matched_idxs_per_image == 0)[0] num_pos = int(self.batch_size_per_image * self.positive_fraction) # protect against not enough positive examples num_pos = min(positive.numel(), num_pos) num_neg = self.batch_size_per_image - num_pos # protect against not enough negative examples num_neg = min(negative.numel(), num_neg) # randomly select positive and negative examples perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] pos_idx_per_image = positive[perm1] neg_idx_per_image = negative[perm2] # create binary mask from indices pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8) neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8) pos_idx_per_image_mask[pos_idx_per_image] = 1 neg_idx_per_image_mask[neg_idx_per_image] = 1 pos_idx.append(pos_idx_per_image_mask) neg_idx.append(neg_idx_per_image_mask) return pos_idx, neg_idx @torch.jit._script_if_tracing def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor: """ Encode a set of proposals with respect to some reference boxes Args: reference_boxes (Tensor): reference boxes proposals (Tensor): boxes to be encoded weights (Tensor[4]): the weights for ``(x, y, w, h)`` """ # perform some unpacking to make it JIT-fusion friendly wx = weights[0] wy = weights[1] ww = weights[2] wh = weights[3] proposals_x1 = proposals[:, 0].unsqueeze(1) proposals_y1 = proposals[:, 1].unsqueeze(1) proposals_x2 = proposals[:, 2].unsqueeze(1) proposals_y2 = proposals[:, 3].unsqueeze(1) reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1) reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1) reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1) reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1) # implementation starts here ex_widths = proposals_x2 - proposals_x1 ex_heights = proposals_y2 - proposals_y1 ex_ctr_x = proposals_x1 + 0.5 * ex_widths ex_ctr_y = proposals_y1 + 0.5 * ex_heights gt_widths = reference_boxes_x2 - reference_boxes_x1 gt_heights = reference_boxes_y2 - reference_boxes_y1 gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights targets_dw = ww * torch.log(gt_widths / ex_widths) targets_dh = wh * torch.log(gt_heights / ex_heights) targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) return targets class BoxCoder: """ This class encodes and decodes a set of bounding boxes into the representation used for training the regressors. """ def __init__( self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16) ) -> None: """ Args: weights (4-element tuple) bbox_xform_clip (float) """ self.weights = weights self.bbox_xform_clip = bbox_xform_clip def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]: boxes_per_image = [len(b) for b in reference_boxes] reference_boxes = torch.cat(reference_boxes, dim=0) proposals = torch.cat(proposals, dim=0) targets = self.encode_single(reference_boxes, proposals) return targets.split(boxes_per_image, 0) def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: """ Encode a set of proposals with respect to some reference boxes Args: reference_boxes (Tensor): reference boxes proposals (Tensor): boxes to be encoded """ dtype = reference_boxes.dtype device = reference_boxes.device weights = torch.as_tensor(self.weights, dtype=dtype, device=device) targets = encode_boxes(reference_boxes, proposals, weights) return targets def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: torch._assert( isinstance(boxes, (list, tuple)), "This function expects boxes of type list or tuple.", ) torch._assert( isinstance(rel_codes, torch.Tensor), "This function expects rel_codes of type torch.Tensor.", ) boxes_per_image = [b.size(0) for b in boxes] concat_boxes = torch.cat(boxes, dim=0) box_sum = 0 for val in boxes_per_image: box_sum += val if box_sum > 0: rel_codes = rel_codes.reshape(box_sum, -1) pred_boxes = self.decode_single(rel_codes, concat_boxes) if box_sum > 0: pred_boxes = pred_boxes.reshape(box_sum, -1, 4) return pred_boxes def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: """ From a set of original boxes and encoded relative box offsets, get the decoded boxes. Args: rel_codes (Tensor): encoded boxes boxes (Tensor): reference boxes. """ boxes = boxes.to(rel_codes.dtype) widths = boxes[:, 2] - boxes[:, 0] heights = boxes[:, 3] - boxes[:, 1] ctr_x = boxes[:, 0] + 0.5 * widths ctr_y = boxes[:, 1] + 0.5 * heights wx, wy, ww, wh = self.weights dx = rel_codes[:, 0::4] / wx dy = rel_codes[:, 1::4] / wy dw = rel_codes[:, 2::4] / ww dh = rel_codes[:, 3::4] / wh # Prevent sending too large values into torch.exp() dw = torch.clamp(dw, max=self.bbox_xform_clip) dh = torch.clamp(dh, max=self.bbox_xform_clip) pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] pred_w = torch.exp(dw) * widths[:, None] pred_h = torch.exp(dh) * heights[:, None] # Distance from center to box's corner. c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w pred_boxes1 = pred_ctr_x - c_to_c_w pred_boxes2 = pred_ctr_y - c_to_c_h pred_boxes3 = pred_ctr_x + c_to_c_w pred_boxes4 = pred_ctr_y + c_to_c_h pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1) return pred_boxes class BoxLinearCoder: """ The linear box-to-box transform defined in FCOS. The transformation is parameterized by the distance from the center of (square) src box to 4 edges of the target box. """ def __init__(self, normalize_by_size: bool = True) -> None: """ Args: normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes. """ self.normalize_by_size = normalize_by_size def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: """ Encode a set of proposals with respect to some reference boxes Args: reference_boxes (Tensor): reference boxes proposals (Tensor): boxes to be encoded Returns: Tensor: the encoded relative box offsets that can be used to decode the boxes. """ # get the center of reference_boxes reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2]) reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3]) # get box regression transformation deltas target_l = reference_boxes_ctr_x - proposals[..., 0] target_t = reference_boxes_ctr_y - proposals[..., 1] target_r = proposals[..., 2] - reference_boxes_ctr_x target_b = proposals[..., 3] - reference_boxes_ctr_y targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1) if self.normalize_by_size: reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0] reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1] reference_boxes_size = torch.stack( (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1 ) targets = targets / reference_boxes_size return targets def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: """ From a set of original boxes and encoded relative box offsets, get the decoded boxes. Args: rel_codes (Tensor): encoded boxes boxes (Tensor): reference boxes. Returns: Tensor: the predicted boxes with the encoded relative box offsets. .. note:: This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``. """ boxes = boxes.to(dtype=rel_codes.dtype) ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2]) ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3]) if self.normalize_by_size: boxes_w = boxes[..., 2] - boxes[..., 0] boxes_h = boxes[..., 3] - boxes[..., 1] list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1) rel_codes = rel_codes * list_box_size pred_boxes1 = ctr_x - rel_codes[..., 0] pred_boxes2 = ctr_y - rel_codes[..., 1] pred_boxes3 = ctr_x + rel_codes[..., 2] pred_boxes4 = ctr_y + rel_codes[..., 3] pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1) return pred_boxes class Matcher: """ This class assigns to each predicted "element" (e.g., a box) a ground-truth element. Each predicted element will have exactly zero or one matches; each ground-truth element may be assigned to zero or more predicted elements. Matching is based on the MxN match_quality_matrix, that characterizes how well each (ground-truth, predicted)-pair match. For example, if the elements are boxes, the matrix may contain box IoU overlap values. The matcher returns a tensor of size N containing the index of the ground-truth element m that matches to prediction n. If there is no match, a negative value is returned. """ BELOW_LOW_THRESHOLD = -1 BETWEEN_THRESHOLDS = -2 __annotations__ = { "BELOW_LOW_THRESHOLD": int, "BETWEEN_THRESHOLDS": int, } def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None: """ Args: high_threshold (float): quality values greater than or equal to this value are candidate matches. low_threshold (float): a lower quality threshold used to stratify matches into three levels: 1) matches >= high_threshold 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold) 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold) allow_low_quality_matches (bool): if True, produce additional matches for predictions that have only low-quality match candidates. See set_low_quality_matches_ for more details. """ self.BELOW_LOW_THRESHOLD = -1 self.BETWEEN_THRESHOLDS = -2 torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold") self.high_threshold = high_threshold self.low_threshold = low_threshold self.allow_low_quality_matches = allow_low_quality_matches def __call__(self, match_quality_matrix: Tensor) -> Tensor: """ Args: match_quality_matrix (Tensor[float]): an MxN tensor, containing the pairwise quality between M ground-truth elements and N predicted elements. Returns: matches (Tensor[int64]): an N tensor where N[i] is a matched gt in [0, M - 1] or a negative value indicating that prediction i could not be matched. """ if match_quality_matrix.numel() == 0: # empty targets or proposals not supported during training if match_quality_matrix.shape[0] == 0: raise ValueError("No ground-truth boxes available for one of the images during training") else: raise ValueError("No proposal boxes available for one of the images during training") # match_quality_matrix is M (gt) x N (predicted) # Max over gt elements (dim 0) to find best gt candidate for each prediction matched_vals, matches = match_quality_matrix.max(dim=0) if self.allow_low_quality_matches: all_matches = matches.clone() else: all_matches = None # type: ignore[assignment] # Assign candidate matches with low quality to negative (unassigned) values below_low_threshold = matched_vals < self.low_threshold between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold) matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD matches[between_thresholds] = self.BETWEEN_THRESHOLDS if self.allow_low_quality_matches: if all_matches is None: torch._assert(False, "all_matches should not be None") else: self.set_low_quality_matches_(matches, all_matches, match_quality_matrix) return matches def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None: """ Produce additional matches for predictions that have only low-quality matches. Specifically, for each ground-truth find the set of predictions that have maximum overlap with it (including ties); for each prediction in that set, if it is unmatched, then match it to the ground-truth with which it has the highest quality value. """ # For each gt, find the prediction with which it has the highest quality highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) # Find the highest quality match available, even if it is low, including ties gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None]) # Example gt_pred_pairs_of_highest_quality: # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]), # tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390])) # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index # Note how gt items 1, 2, 3, and 5 each have two ties pred_inds_to_update = gt_pred_pairs_of_highest_quality[1] matches[pred_inds_to_update] = all_matches[pred_inds_to_update] class SSDMatcher(Matcher): def __init__(self, threshold: float) -> None: super().__init__(threshold, threshold, allow_low_quality_matches=False) def __call__(self, match_quality_matrix: Tensor) -> Tensor: matches = super().__call__(match_quality_matrix) # For each gt, find the prediction with which it has the highest quality _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1) matches[highest_quality_pred_foreach_gt] = torch.arange( highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device ) return matches def overwrite_eps(model: nn.Module, eps: float) -> None: """ This method overwrites the default eps values of all the FrozenBatchNorm2d layers of the model with the provided value. This is necessary to address the BC-breaking change introduced by the bug-fix at pytorch/vision#2933. The overwrite is applied only when the pretrained weights are loaded to maintain compatibility with previous versions. Args: model (nn.Module): The model on which we perform the overwrite. eps (float): The new value of eps. """ for module in model.modules(): if isinstance(module, FrozenBatchNorm2d): module.eps = eps def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]: """ This method retrieves the number of output channels of a specific model. Args: model (nn.Module): The model for which we estimate the out_channels. It should return a single Tensor or an OrderedDict[Tensor]. size (Tuple[int, int]): The size (wxh) of the input. Returns: out_channels (List[int]): A list of the output channels of the model. """ in_training = model.training model.eval() with torch.no_grad(): # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values device = next(model.parameters()).device tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device) features = model(tmp_img) if isinstance(features, torch.Tensor): features = OrderedDict([("0", features)]) out_channels = [x.size(1) for x in features.values()] if in_training: model.train() return out_channels @torch.jit.unused def _fake_cast_onnx(v: Tensor) -> int: return v # type: ignore[return-value] def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int: """ ONNX spec requires the k-value to be less than or equal to the number of inputs along provided dim. Certain models use the number of elements along a particular axis instead of K if K exceeds the number of elements along that axis. Previously, python's min() function was used to determine whether to use the provided k-value or the specified dim axis value. However, in cases where the model is being exported in tracing mode, python min() is static causing the model to be traced incorrectly and eventually fail at the topk node. In order to avoid this situation, in tracing mode, torch.min() is used instead. Args: input (Tensor): The original input tensor. orig_kval (int): The provided k-value. axis(int): Axis along which we retrieve the input size. Returns: min_kval (int): Appropriately selected k-value. """ if not torch.jit.is_tracing(): return min(orig_kval, input.size(axis)) axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0) min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0)) return _fake_cast_onnx(min_kval) def _box_loss( type: str, box_coder: BoxCoder, anchors_per_image: Tensor, matched_gt_boxes_per_image: Tensor, bbox_regression_per_image: Tensor, cnf: Optional[Dict[str, float]] = None, ) -> Tensor: torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}") if type == "l1": target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum") elif type == "smooth_l1": target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0 return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta) else: bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image) eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7 if type == "ciou": return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) if type == "diou": return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps) # otherwise giou return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)