import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss


class BalancedPositiveNegativeSampler:
    """
    This class samples batches, ensuring that they contain a fixed proportion of positives
    """

    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
        """
        Args:
            batch_size_per_image (int): number of elements to be selected per image
            positive_fraction (float): fraction of positive elements per batch
        """
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Args:
            matched_idxs: list of tensors containing -1, 0 or positive values.
                Each tensor corresponds to a specific image.
                -1 values are ignored, 0 are considered as negatives and > 0 as
                positives.

        Returns:
            pos_idx (list[tensor])
            neg_idx (list[tensor])

        Returns two lists of binary masks for each image.
        The first list contains the positive elements that were selected,
        and the second list contains the negative elements.
        """
        pos_idx = []
        neg_idx = []
        for matched_idxs_per_image in matched_idxs:
            positive = torch.where(matched_idxs_per_image >= 1)[0]
            negative = torch.where(matched_idxs_per_image == 0)[0]

            num_pos = int(self.batch_size_per_image * self.positive_fraction)
            # protect against not enough positive examples
            num_pos = min(positive.numel(), num_pos)
            num_neg = self.batch_size_per_image - num_pos
            # protect against not enough negative examples
            num_neg = min(negative.numel(), num_neg)

            # randomly select positive and negative examples
            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

            pos_idx_per_image = positive[perm1]
            neg_idx_per_image = negative[perm2]

            # create binary mask from indices
            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)

            pos_idx_per_image_mask[pos_idx_per_image] = 1
            neg_idx_per_image_mask[neg_idx_per_image] = 1

            pos_idx.append(pos_idx_per_image_mask)
            neg_idx.append(neg_idx_per_image_mask)

        return pos_idx, neg_idx
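
# Illustrative usage sketch (not part of the library API); the matched indices below are
# made-up values chosen only to show the expected input/output shapes:
#
#   sampler = BalancedPositiveNegativeSampler(batch_size_per_image=4, positive_fraction=0.5)
#   matched_idxs = [torch.tensor([-1, 0, 0, 2, 1, 0])]  # one image: 2 positives, 3 negatives
#   pos_masks, neg_masks = sampler(matched_idxs)
#   sampled_inds = torch.where(pos_masks[0] | neg_masks[0])[0]  # indices kept for the loss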


@torch.jit._script_if_tracing
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
    """
    Encode a set of proposals with respect to some
    reference boxes

    Args:
        reference_boxes (Tensor): reference boxes
        proposals (Tensor): boxes to be encoded
        weights (Tensor[4]): the weights for ``(x, y, w, h)``
    """

    # perform some unpacking to make it JIT-fusion friendly
    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    proposals_x1 = proposals[:, 0].unsqueeze(1)
    proposals_y1 = proposals[:, 1].unsqueeze(1)
    proposals_x2 = proposals[:, 2].unsqueeze(1)
    proposals_y2 = proposals[:, 3].unsqueeze(1)

    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)

    # implementation starts here
    ex_widths = proposals_x2 - proposals_x1
    ex_heights = proposals_y2 - proposals_y1
    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
    ex_ctr_y = proposals_y1 + 0.5 * ex_heights

    gt_widths = reference_boxes_x2 - reference_boxes_x1
    gt_heights = reference_boxes_y2 - reference_boxes_y1
    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights

    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * torch.log(gt_widths / ex_widths)
    targets_dh = wh * torch.log(gt_heights / ex_heights)

    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
    return targets
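
# In short, for a proposal p with center (px, py) and size (pw, ph) and a reference box g
# with center (gx, gy) and size (gw, gh), the targets computed above are the classic
# R-CNN box parameterization:
#   dx = wx * (gx - px) / pw,    dy = wy * (gy - py) / ph,
#   dw = ww * log(gw / pw),      dh = wh * log(gh / ph)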


class BoxCoder:
    """
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
    ) -> None:
        """
        Args:
            weights (4-element tuple): scaling factors applied to the ``(dx, dy, dw, dh)`` deltas
            bbox_xform_clip (float): upper bound on ``dw``/``dh`` before ``torch.exp`` during
                decoding, used to prevent overflow
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
        boxes_per_image = [len(b) for b in reference_boxes]
        reference_boxes = torch.cat(reference_boxes, dim=0)
        proposals = torch.cat(proposals, dim=0)
        targets = self.encode_single(reference_boxes, proposals)
        return targets.split(boxes_per_image, 0)

    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some
        reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded
        """
        dtype = reference_boxes.dtype
        device = reference_boxes.device
        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
        targets = encode_boxes(reference_boxes, proposals, weights)
        return targets

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        torch._assert(
            isinstance(boxes, (list, tuple)),
            "This function expects boxes of type list or tuple.",
        )
        torch._assert(
            isinstance(rel_codes, torch.Tensor),
            "This function expects rel_codes of type torch.Tensor.",
        )
        boxes_per_image = [b.size(0) for b in boxes]
        concat_boxes = torch.cat(boxes, dim=0)
        box_sum = 0
        for val in boxes_per_image:
            box_sum += val
        if box_sum > 0:
            rel_codes = rel_codes.reshape(box_sum, -1)
        pred_boxes = self.decode_single(rel_codes, concat_boxes)
        if box_sum > 0:
            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
        return pred_boxes

    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.
        """
        boxes = boxes.to(rel_codes.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        dw = rel_codes[:, 2::4] / ww
        dh = rel_codes[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        # Distance from center to box's corner.
        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w

        pred_boxes1 = pred_ctr_x - c_to_c_w
        pred_boxes2 = pred_ctr_y - c_to_c_h
        pred_boxes3 = pred_ctr_x + c_to_c_w
        pred_boxes4 = pred_ctr_y + c_to_c_h
        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
        return pred_boxes
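
# Illustrative round-trip sketch (not part of the module); the boxes are made-up xyxy values.
# decode inverts encode_single, so decoding the encoded deltas recovers the reference boxes:
#
#   coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
#   gt = torch.tensor([[10.0, 10.0, 50.0, 80.0]])
#   anchors = torch.tensor([[12.0, 8.0, 48.0, 90.0]])
#   deltas = coder.encode_single(gt, anchors)      # shape [1, 4]
#   recovered = coder.decode(deltas, [anchors])    # shape [1, 1, 4], approximately equal to gt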


class BoxLinearCoder:
    """
    The linear box-to-box transform defined in FCOS. The transformation is parameterized
    by the distance from the center of (square) src box to 4 edges of the target box.
    """

    def __init__(self, normalize_by_size: bool = True) -> None:
        """
        Args:
            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
        """
        self.normalize_by_size = normalize_by_size

    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded

        Returns:
            Tensor: the encoded relative box offsets that can be used to
            decode the boxes.
        """

        # get the center of reference_boxes
        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])

        # get box regression transformation deltas
        target_l = reference_boxes_ctr_x - proposals[..., 0]
        target_t = reference_boxes_ctr_y - proposals[..., 1]
        target_r = proposals[..., 2] - reference_boxes_ctr_x
        target_b = proposals[..., 3] - reference_boxes_ctr_y

        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)

        if self.normalize_by_size:
            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
            reference_boxes_size = torch.stack(
                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
            )
            targets = targets / reference_boxes_size
        return targets

    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.

        Returns:
            Tensor: the predicted boxes with the encoded relative box offsets.

        .. note::
            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension,
            i.e. ``len(rel_codes) == len(boxes)``.
        """
        boxes = boxes.to(dtype=rel_codes.dtype)

        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])

        if self.normalize_by_size:
            boxes_w = boxes[..., 2] - boxes[..., 0]
            boxes_h = boxes[..., 3] - boxes[..., 1]

            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
            rel_codes = rel_codes * list_box_size

        pred_boxes1 = ctr_x - rel_codes[..., 0]
        pred_boxes2 = ctr_y - rel_codes[..., 1]
        pred_boxes3 = ctr_x + rel_codes[..., 2]
        pred_boxes4 = ctr_y + rel_codes[..., 3]

        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
        return pred_boxes
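
# Illustrative sketch (values are made up): the first argument to encode supplies the center
# and the size used for normalization, the second the box whose edges are encoded as (l, t, r, b):
#
#   coder = BoxLinearCoder(normalize_by_size=True)
#   src = torch.tensor([[20.0, 20.0, 40.0, 40.0]])     # center (30, 30), size 20 x 20
#   target = torch.tensor([[10.0, 15.0, 50.0, 55.0]])
#   deltas = coder.encode(src, target)                 # tensor([[1.00, 0.75, 1.00, 1.25]])
#   recovered = coder.decode(deltas, src)              # approximately equal to target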


class Matcher:
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be assigned to zero or more predicted elements.

    Matching is based on the MxN match_quality_matrix, which characterizes how well
    each (ground-truth, predicted) pair matches. For example, if the elements are
    boxes, the matrix may contain box IoU overlap values.

    The matcher returns a tensor of size N containing the index of the ground-truth
    element m that matches to prediction n. If there is no match, a negative value
    is returned.
    """

    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    __annotations__ = {
        "BELOW_LOW_THRESHOLD": int,
        "BETWEEN_THRESHOLDS": int,
    }

    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.
        """
        self.BELOW_LOW_THRESHOLD = -1
        self.BETWEEN_THRESHOLDS = -2
        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted elements.

        Returns:
            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
                [0, M - 1] or a negative value indicating that prediction i could not
                be matched.
        """
        if match_quality_matrix.numel() == 0:
            # empty targets or proposals not supported during training
            if match_quality_matrix.shape[0] == 0:
                raise ValueError("No ground-truth boxes available for one of the images during training")
            else:
                raise ValueError("No proposal boxes available for one of the images during training")

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            all_matches = matches.clone()
        else:
            all_matches = None  # type: ignore[assignment]

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            if all_matches is None:
                torch._assert(False, "all_matches should not be None")
            else:
                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.
        """
        # For each gt, find the prediction with which it has the highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find the highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
        # Example gt_pred_pairs_of_highest_quality:
        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
        # Each element in the first tensor is a gt index, and each element in the second tensor is a prediction index
        # Note how gt items 1, 2, 3, and 5 each have two ties
        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
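
# Illustrative sketch (the IoU values are made up), using RPN-style thresholds:
#
#   matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=True)
#   iou = torch.tensor([[0.80, 0.20, 0.10],    # 2 gt boxes x 3 predictions
#                       [0.10, 0.40, 0.05]])
#   matched_idxs = matcher(iou)                # tensor([0, 1, -1])
#   # prediction 0 -> gt 0 (IoU >= 0.7); prediction 1 falls BETWEEN_THRESHOLDS, but it is the
#   # best candidate for gt 1, so the low-quality pass assigns it gt 1; prediction 2 stays
#   # BELOW_LOW_THRESHOLD.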


class SSDMatcher(Matcher):
    def __init__(self, threshold: float) -> None:
        super().__init__(threshold, threshold, allow_low_quality_matches=False)

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        matches = super().__call__(match_quality_matrix)

        # For each gt, find the prediction with which it has the highest quality
        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
        matches[highest_quality_pred_foreach_gt] = torch.arange(
            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
        )

        return matches
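
# Unlike the base Matcher, SSDMatcher additionally forces every ground-truth box to claim its
# single best prediction, so no gt is left unmatched even when all of its qualities fall below
# the threshold.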


def overwrite_eps(model: nn.Module, eps: float) -> None:
    """
    This method overwrites the default eps values of all the
    FrozenBatchNorm2d layers of the model with the provided value.
    This is necessary to address the BC-breaking change introduced
    by the bug-fix at pytorch/vision#2933. The overwrite is applied
    only when the pretrained weights are loaded to maintain compatibility
    with previous versions.

    Args:
        model (nn.Module): The model on which we perform the overwrite.
        eps (float): The new value of eps.
    """
    for module in model.modules():
        if isinstance(module, FrozenBatchNorm2d):
            module.eps = eps


def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
    """
    This method retrieves the number of output channels of a specific model.

    Args:
        model (nn.Module): The model for which we estimate the out_channels.
            It should return a single Tensor or an OrderedDict[Tensor].
        size (Tuple[int, int]): The size (wxh) of the input.

    Returns:
        out_channels (List[int]): A list of the output channels of the model.
    """
    in_training = model.training
    model.eval()

    with torch.no_grad():
        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
        device = next(model.parameters()).device
        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
        features = model(tmp_img)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])
        out_channels = [x.size(1) for x in features.values()]

    if in_training:
        model.train()

    return out_channels
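
# Illustrative sketch (the backbone below is an arbitrary stand-in, not a requirement):
#
#   backbone = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU())
#   retrieve_out_channels(backbone, size=(320, 320))   # -> [16]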


@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> int:
    return v  # type: ignore[return-value]


def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
    """
    ONNX spec requires the k-value to be less than or equal to the number of inputs along
    the provided dim. Certain models use the number of elements along a particular axis instead of K
    if K exceeds the number of elements along that axis. Previously, python's min() function was
    used to determine whether to use the provided k-value or the specified dim axis value.
    However, in cases where the model is being exported in tracing mode, python min() is
    static, causing the model to be traced incorrectly and eventually to fail at the topk node.
    In order to avoid this situation, in tracing mode, torch.min() is used instead.

    Args:
        input (Tensor): The original input tensor.
        orig_kval (int): The provided k-value.
        axis (int): Axis along which we retrieve the input size.

    Returns:
        min_kval (int): Appropriately selected k-value.
    """
    if not torch.jit.is_tracing():
        return min(orig_kval, input.size(axis))
    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
    return _fake_cast_onnx(min_kval)
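
# Illustrative sketch (values are made up): outside of tracing this is just a plain min():
#
#   scores = torch.rand(300)
#   k = _topk_min(scores, orig_kval=1000, axis=0)   # -> 300, since only 300 elements exist
#   topk_scores, topk_idxs = scores.topk(k)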


def _box_loss(
    type: str,
    box_coder: BoxCoder,
    anchors_per_image: Tensor,
    matched_gt_boxes_per_image: Tensor,
    bbox_regression_per_image: Tensor,
    cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")

    if type == "l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
    elif type == "smooth_l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
    else:
        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
        if type == "ciou":
            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
        if type == "diou":
            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
        # otherwise giou
        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
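
# Illustrative sketch (shapes and values are made up): the l1/smooth_l1 variants compare
# regression deltas, while the IoU-based variants first decode the deltas back into boxes.
#
#   coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
#   anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
#   gt = torch.tensor([[1.0, 1.0, 11.0, 11.0]])
#   preds = torch.zeros(1, 4)   # raw regression output for one anchor
#   loss = _box_loss("smooth_l1", coder, anchors, gt, preds, cnf={"beta": 1.0 / 9})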
|