123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417 |
- from typing import Tuple
- import torch
- import torchvision
- from torch import Tensor
- from torchvision.extension import _assert_has_ops
- from ..utils import _log_api_usage_once
- from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh
- from ._utils import _upcast
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """Perform non-maximum suppression (NMS) on boxes by intersection-over-union (IoU).

    Lower-scoring boxes whose IoU with an already-kept, higher-scoring box
    exceeds ``iou_threshold`` are discarded. When several boxes share the
    exact same score and satisfy the IoU criterion against a reference box,
    the box that gets kept is not guaranteed to match between CPU and GPU —
    the same caveat as ``torch.argsort`` with repeated values.

    Args:
        boxes (Tensor[N, 4])): boxes to perform NMS on. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(nms)
    _assert_has_ops()
    # Delegate to the compiled torchvision kernel.
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """Perform non-maximum suppression in a batched fashion.

    Each value in ``idxs`` identifies a category; NMS is never applied
    between boxes of different categories.

    Args:
        boxes (Tensor[N, 4]): boxes where NMS will be performed. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        idxs (Tensor[N]): indices of the categories for each one of the boxes.
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(batched_nms)
    # Dispatch thresholds were benchmarked in
    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
    cutoff = 4000 if boxes.device.type == "cpu" else 20000
    if boxes.numel() > cutoff and not torchvision._is_tracing():
        # Large inputs: per-class loop avoids the coordinate-trick blowup.
        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
@torch.jit._script_if_tracing
def _batched_nms_coordinate_trick(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    # Run class-independent NMS in a single nms() call: shift every box by a
    # per-class offset that depends only on the class index and is larger than
    # any coordinate, so boxes from different classes can never overlap.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
    shifted_boxes = boxes + offsets[:, None]
    return nms(shifted_boxes, scores, iou_threshold)
@torch.jit._script_if_tracing
def _batched_nms_vanilla(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    # Detectron2-style implementation: call nms() separately for each class,
    # collect the survivors in a boolean mask, then merge.
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
    for class_id in torch.unique(idxs):
        members = torch.where(idxs == class_id)[0]
        kept = nms(boxes[members], scores[members], iou_threshold)
        keep_mask[members[kept]] = True
    keep_indices = torch.where(keep_mask)[0]
    # Order survivors by descending score, matching plain nms() output.
    order = scores[keep_indices].sort(descending=True)[1]
    return keep_indices[order]
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """Remove boxes which contains at least one side smaller than min_size.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        min_size (float): minimum size

    Returns:
        Tensor[K]: indices of the boxes that have both sides
        larger than min_size
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(remove_small_boxes)
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    large_enough = (widths >= min_size) & (heights >= min_size)
    return torch.where(large_enough)[0]
def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
    """Clip boxes so that they lie inside an image of size `size`.

    Args:
        boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        size (Tuple[height, width]): size of the image

    Returns:
        Tensor[N, 4]: clipped boxes
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(clip_boxes_to_image)
    dim = boxes.dim()
    xs = boxes[..., 0::2]
    ys = boxes[..., 1::2]
    height, width = size
    if torchvision._is_tracing():
        # Under ONNX tracing, clamp with Python scalars doesn't export;
        # use tensor-valued bounds via min/max instead.
        zero = torch.tensor(0, dtype=boxes.dtype, device=boxes.device)
        xs = torch.min(torch.max(xs, zero), torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        ys = torch.min(torch.max(ys, zero), torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
    else:
        xs = xs.clamp(min=0, max=width)
        ys = ys.clamp(min=0, max=height)
    return torch.stack((xs, ys), dim=dim).reshape(boxes.shape)
def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """Convert boxes from given in_fmt to out_fmt.

    Supported in_fmt and out_fmt are:

    'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
    This is the format that torchvision utilities expect.

    'xywh': boxes are represented via corner, width and height, x1, y2 being top left, w, h being
    width and height.

    'cxcywh': boxes are represented via centre, width and height, cx, cy being center of box, w, h
    being width and height.

    Args:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
        out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']

    Returns:
        Tensor[N, 4]: Boxes into converted format.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_convert)
    valid_fmts = ("xyxy", "xywh", "cxcywh")
    if in_fmt not in valid_fmts or out_fmt not in valid_fmts:
        raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")
    if in_fmt == out_fmt:
        return boxes.clone()
    if in_fmt != "xyxy" and out_fmt != "xyxy":
        # Neither endpoint is xyxy: route through xyxy as the pivot format.
        boxes = _box_xywh_to_xyxy(boxes) if in_fmt == "xywh" else _box_cxcywh_to_xyxy(boxes)
        in_fmt = "xyxy"
    if in_fmt == "xyxy":
        if out_fmt == "xywh":
            boxes = _box_xyxy_to_xywh(boxes)
        elif out_fmt == "cxcywh":
            boxes = _box_xyxy_to_cxcywh(boxes)
    elif out_fmt == "xyxy":
        if in_fmt == "xywh":
            boxes = _box_xywh_to_xyxy(boxes)
        elif in_fmt == "cxcywh":
            boxes = _box_cxcywh_to_xyxy(boxes)
    return boxes
def box_area(boxes: Tensor) -> Tensor:
    """Compute the area of a set of bounding boxes given by (x1, y1, x2, y2) coordinates.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        Tensor[N]: the area for each box
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_area)
    # Upcast first so the multiply cannot overflow in low-precision dtypes.
    boxes = _upcast(boxes)
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    # Pairwise intersection and union areas between every box in boxes1 and
    # every box in boxes2, returned as two [N, M] tensors.
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)
    # Overlap rectangle: max of the top-lefts, min of the bottom-rights.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
    # clamp(min=0) zeroes out pairs that do not overlap at all.
    wh = _upcast(bottom_right - top_left).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
    union = area1[:, None] + area2 - inter
    return inter, union
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """Return intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_iou)
    inter, union = _box_inter_union(boxes1, boxes2)
    return inter / union
# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """Return generalized intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(generalized_box_iou)
    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union
    # Smallest enclosing box of each (boxes1[i], boxes2[j]) pair.
    enclose_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclose_wh = _upcast(enclose_br - enclose_tl).clamp(min=0)  # [N,M,2]
    enclose_area = enclose_wh[:, :, 0] * enclose_wh[:, :, 1]
    # GIoU penalizes IoU by the empty fraction of the enclosing box.
    return iou - (enclose_area - union) / enclose_area
def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """Return complete intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou)
    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)
    diou, iou = _box_diou_iou(boxes1, boxes2, eps)
    # Aspect-ratio consistency term v from the CIoU formulation.
    pred_w = boxes1[:, None, 2] - boxes1[:, None, 0]
    pred_h = boxes1[:, None, 3] - boxes1[:, None, 1]
    gt_w = boxes2[:, 2] - boxes2[:, 0]
    gt_h = boxes2[:, 3] - boxes2[:, 1]
    v = (4 / (torch.pi**2)) * torch.pow(torch.atan(pred_w / pred_h) - torch.atan(gt_w / gt_h), 2)
    # The trade-off weight alpha is treated as a constant (no gradient).
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    return diou - alpha * v
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """Return distance intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(distance_box_iou)
    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)
    # The shared helper also returns plain IoU, which is not needed here.
    diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
    return diou
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:
    # Compute the pairwise DIoU matrix and the plain IoU matrix it builds on.
    iou = box_iou(boxes1, boxes2)
    # Diagonal length (squared) of the smallest enclosing box of each pair;
    # eps keeps the later division safe.
    enclose_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclose_wh = _upcast(enclose_br - enclose_tl).clamp(min=0)  # [N,M,2]
    diagonal_distance_squared = (enclose_wh[:, :, 0] ** 2) + (enclose_wh[:, :, 1] ** 2) + eps
    # Box centers.
    cx1 = (boxes1[:, 0] + boxes1[:, 2]) / 2
    cy1 = (boxes1[:, 1] + boxes1[:, 3]) / 2
    cx2 = (boxes2[:, 0] + boxes2[:, 2]) / 2
    cy2 = (boxes2[:, 1] + boxes2[:, 3]) / 2
    # Squared distance between the two centers, broadcast to [N, M].
    centers_distance_squared = (_upcast((cx1[:, None] - cx2[None, :])) ** 2) + (
        _upcast((cy1[:, None] - cy2[None, :])) ** 2
    )
    # DIoU = IoU minus the normalized squared center distance.
    return iou - (centers_distance_squared / diagonal_distance_squared), iou
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """Compute the bounding boxes around the provided masks.

    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
            and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: bounding boxes
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(masks_to_boxes)
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
    num_masks = masks.shape[0]
    bounding_boxes = torch.zeros((num_masks, 4), device=masks.device, dtype=torch.float)
    for i, mask in enumerate(masks):
        # NOTE(review): assumes each mask has at least one nonzero pixel;
        # an all-zero mask would make min/max fail on an empty tensor — confirm upstream.
        ys, xs = torch.where(mask != 0)
        bounding_boxes[i, 0] = torch.min(xs)
        bounding_boxes[i, 1] = torch.min(ys)
        bounding_boxes[i, 2] = torch.max(xs)
        bounding_boxes[i, 3] = torch.max(ys)
    return bounding_boxes
|