# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import math
import re
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision

from ultralytics.utils import LOGGER

class Profile(contextlib.ContextDecorator):
    """
    YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
    """

    def __init__(self, t=0.0):
        """
        Initialize the Profile class.

        Args:
            t (float): Initial time. Defaults to 0.0.
        """
        self.t = t
        self.cuda = torch.cuda.is_available()

    def __enter__(self):
        """Start timing."""
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):  # noqa
        """Stop timing."""
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt

    def time(self):
        """Get current time."""
        if self.cuda:
            torch.cuda.synchronize()
        return time.time()

def segment2box(segment, width=640, height=640):
    """
    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).

    Args:
        segment (torch.Tensor): the segment label
        width (int): the width of the image. Defaults to 640
        height (int): the height of the image. Defaults to 640

    Returns:
        (np.ndarray): the minimum and maximum x and y values of the segment.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
        4, dtype=segment.dtype)  # xyxy
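
# Illustrative example (not part of the original module): points outside the image
# are dropped before the enclosing box is computed, so the box below comes from the
# two in-bounds vertices only.
# >>> seg = np.array([[-10.0, 5.0], [100.0, 5.0], [100.0, 200.0]])
# >>> segment2box(seg)
# array([100.,   5., 100., 200.])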

def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
    """
    Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
    (img1_shape) to the shape of a different image (img0_shape).

    Args:
        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
        img0_shape (tuple): the shape of the target image, in the format of (height, width).
        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
            calculated based on the size difference between the two images.
        padding (bool): If True, assume the boxes are based on an image augmented with YOLO-style letterbox padding.
            If False, do regular rescaling.

    Returns:
        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
            (img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        boxes[..., [0, 2]] -= pad[0]  # x padding
        boxes[..., [1, 3]] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
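
# Illustrative example (not part of the original module): boxes predicted on a
# 640x640 letterboxed input are mapped back to the original 480x640 frame. The
# 480-pixel-high image was padded by 80 px on top and bottom, so y values shift
# up by 80 while x values are unchanged (gain is 1.0 here).
# >>> b = torch.tensor([[100.0, 180.0, 200.0, 280.0]])
# >>> scale_boxes((640, 640), b, (480, 640))
# tensor([[100., 100., 200., 200.]])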

def make_divisible(x, divisor):
    """
    Returns the smallest number greater than or equal to x that is divisible by the given divisor.

    Args:
        x (int): The number to make divisible.
        divisor (int | torch.Tensor): The divisor.

    Returns:
        (int): The smallest number >= x divisible by the divisor.
    """
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor
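
# Illustrative example (not part of the original module): channel counts are
# rounded up to the next multiple of the divisor.
# >>> make_divisible(97, 32)
# 128
# >>> make_divisible(64, 32)
# 64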

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nc=0,  # number of classes (optional)
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Args:
        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing the predicted boxes, classes, and masks. The tensor should be in the format
            output by a model, such as YOLO.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, the model is agnostic to the number of classes, and all
            classes will be considered as one.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
            list contains the apriori labels for a given image. The list should be in the format
            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes passed into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels.

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4  # number of mask coefficients
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

        # # Experimental
        # merge = False  # use merge-NMS
        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
        #     from .metrics import box_iou
        #     iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
        #     weights = iou * scores[None]  # box weights
        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
        #     redundant = True  # require redundant detections
        #     if redundant:
        #         i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
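
# Illustrative usage (not part of the original module): running NMS on raw
# detection output for an assumed 80-class model with 6300 anchors. Random
# scores stand in for real predictions, so only the column layout is meaningful.
# >>> pred = torch.randn(1, 84, 6300)  # (batch, 4 box coords + 80 class scores, anchors)
# >>> results = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
# >>> results[0].shape[1]  # columns: x1, y1, x2, y2, confidence, class
# 6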

def clip_boxes(boxes, shape):
    """
    Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.

    Args:
        boxes (torch.Tensor): the bounding boxes to clip
        shape (tuple): the shape of the image
    """
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[..., 0].clamp_(0, shape[1])  # x1
        boxes[..., 1].clamp_(0, shape[0])  # y1
        boxes[..., 2].clamp_(0, shape[1])  # x2
        boxes[..., 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2

def clip_coords(coords, shape):
    """
    Clip line coordinates to the image boundaries.

    Args:
        coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
        shape (tuple): A tuple of integers representing the size of the image in the format (height, width).

    Returns:
        (None): The function modifies the input `coords` in place, by clipping each coordinate to the image boundaries.
    """
    if isinstance(coords, torch.Tensor):  # faster individually
        coords[..., 0].clamp_(0, shape[1])  # x
        coords[..., 1].clamp_(0, shape[0])  # y
    else:  # np.array (faster grouped)
        coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
        coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y

def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Takes a mask, and resizes it to the original image size.

    Args:
        masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
        im0_shape (tuple): the original image shape
        ratio_pad (tuple): the ratio of the padding to the original image.

    Returns:
        masks (np.ndarray): The masks rescaled to im0_shape.
    """
    # Rescale coordinates (xyxy) from im1_shape to im0_shape
    im1_shape = masks.shape
    if im1_shape[:2] == im0_shape[:2]:
        return masks
    if ratio_pad is None:  # calculate from im0_shape
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    top, left = int(pad[1]), int(pad[0])  # y, x
    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])

    if len(masks.shape) < 2:
        raise ValueError(f'masks should have 2 or 3 dimensions, but got {len(masks.shape)}')
    masks = masks[top:bottom, left:right]
    masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
    if len(masks.shape) == 2:
        masks = masks[:, :, None]
    return masks

def xyxy2xywh(x):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
    """
    assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y

def xywh2xyxy(x):
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    dw = x[..., 2] / 2  # half-width
    dh = x[..., 3] / 2  # half-height
    y[..., 0] = x[..., 0] - dw  # top left x
    y[..., 1] = x[..., 1] - dh  # top left y
    y[..., 2] = x[..., 0] + dw  # bottom right x
    y[..., 3] = x[..., 1] + dh  # bottom right y
    return y
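
# Illustrative example (not part of the original module): xywh2xyxy and xyxy2xywh
# are inverses of each other.
# >>> xywh = torch.tensor([[50.0, 50.0, 20.0, 10.0]])  # center (50, 50), 20x10 box
# >>> xyxy = xywh2xyxy(xywh)
# >>> xyxy
# tensor([[40., 45., 60., 55.]])
# >>> xyxy2xywh(xyxy)
# tensor([[50., 50., 20., 10.]])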

def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """
    Convert normalized bounding box coordinates to pixel coordinates.

    Args:
        x (np.ndarray | torch.Tensor): The bounding box coordinates.
        w (int): Width of the image. Defaults to 640
        h (int): Height of the image. Defaults to 640
        padw (int): Padding width. Defaults to 0
        padh (int): Padding height. Defaults to 0

    Returns:
        y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
            x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
    """
    assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
    return y
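
# Illustrative example (not part of the original module): a normalized label
# (cx=0.5, cy=0.5, w=0.25, h=0.5) on a 640x640 image becomes pixel corners.
# >>> xywhn2xyxy(np.array([[0.5, 0.5, 0.25, 0.5]]), w=640, h=640)
# array([[240., 160., 400., 480.]])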

def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format,
    where x, y, width and height are normalized to image dimensions.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
        w (int): The width of the image. Defaults to 640
        h (int): The height of the image. Defaults to 640
        clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
        eps (float): The minimum value of the box's width and height. Defaults to 0.0

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
    """
    if clip:
        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
    assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
    return y

def xywh2ltwh(x):
    """
    Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.

    Args:
        x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    return y

def xyxy2ltwh(x):
    """
    Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y

def ltwh2xywh(x):
    """
    Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.

    Args:
        x (np.ndarray | torch.Tensor): the input boxes in ltwh format

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
    y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
    return y

def xyxyxyxy2xywhr(corners):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].

    Args:
        corners (numpy.ndarray | torch.Tensor): Input corners of shape (n, 8).

    Returns:
        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
    """
    is_numpy = isinstance(corners, np.ndarray)
    atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt)

    x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
    cx = (x1 + x3) / 2  # center is the midpoint of a diagonal
    cy = (y1 + y3) / 2
    dx21 = x2 - x1
    dy21 = y2 - y1
    w = sqrt(dx21 ** 2 + dy21 ** 2)  # length of edge 1-2
    h = sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)  # length of edge 2-3
    rotation = atan2(-dy21, dx21)  # negate dy since the image y-axis points down
    rotation *= 180.0 / math.pi  # radians to degrees
    return np.vstack((cx, cy, w, h, rotation)).T if is_numpy else torch.stack((cx, cy, w, h, rotation), dim=1)

def xywhr2xyxyxyxy(center):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].

    Args:
        center (numpy.ndarray | torch.Tensor): Input data in [cx, cy, w, h, rotation] format of shape (n, 5).

    Returns:
        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
    """
    is_numpy = isinstance(center, np.ndarray)
    cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)

    cx, cy, w, h, rotation = center.T
    rotation = rotation * (math.pi / 180.0)  # degrees to radians (out-of-place to avoid mutating the input view)
    dx = w / 2
    dy = h / 2
    cos_rot = cos(rotation)
    sin_rot = sin(rotation)
    dx_cos_rot = dx * cos_rot
    dx_sin_rot = dx * sin_rot
    dy_cos_rot = dy * cos_rot
    dy_sin_rot = dy * sin_rot
    x1 = cx - dx_cos_rot - dy_sin_rot
    y1 = cy + dx_sin_rot - dy_cos_rot
    x2 = cx + dx_cos_rot - dy_sin_rot
    y2 = cy - dx_sin_rot - dy_cos_rot
    x3 = cx + dx_cos_rot + dy_sin_rot
    y3 = cy - dx_sin_rot + dy_cos_rot
    x4 = cx - dx_cos_rot + dy_sin_rot
    y4 = cy + dx_sin_rot + dy_cos_rot
    return np.vstack((x1, y1, x2, y2, x3, y3, x4, y4)).T if is_numpy else torch.stack(
        (x1, y1, x2, y2, x3, y3, x4, y4), dim=1)
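
# Illustrative example (not part of the original module): an unrotated 40x20 box
# centered at (100, 100) expands to its four corners, starting at the top-left.
# >>> xywhr2xyxyxyxy(np.array([[100.0, 100.0, 40.0, 20.0, 0.0]]))
# array([[ 80.,  90., 120.,  90., 120., 110.,  80., 110.]])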

def ltwh2xyxy(x):
    """
    It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): the input boxes

    Returns:
        y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] + x[..., 0]  # x2 = x1 + w
    y[..., 3] = x[..., 3] + x[..., 1]  # y2 = y1 + h
    return y
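
# Illustrative example (not part of the original module): a COCO-style ltwh box
# [x1=10, y1=20, w=30, h=40] maps to corner format, or to center format.
# >>> ltwh2xyxy(np.array([[10.0, 20.0, 30.0, 40.0]]))
# array([[10., 20., 40., 60.]])
# >>> ltwh2xywh(np.array([[10.0, 20.0, 30.0, 40.0]]))
# array([[25., 40., 30., 40.]])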

def segments2boxes(segments):
    """
    It converts segment labels to box labels, i.e. (xy1, xy2, ...) to (xywh).

    Args:
        segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates

    Returns:
        (np.ndarray): the xywh coordinates of the bounding boxes.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    return xyxy2xywh(np.array(boxes))  # xywh

def resample_segments(segments, n=1000):
    """
    Inputs a list of segments (each an (m, 2) array) and returns the list with each segment up-sampled to n points.

    Args:
        segments (list): a list of (m, 2) arrays, where m is the number of points in the segment.
        n (int): number of points to resample each segment to. Defaults to 1000

    Returns:
        segments (list): the resampled segments.
    """
    for i, s in enumerate(segments):
        s = np.concatenate((s, s[0:1, :]), axis=0)  # close the polygon
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)],
                                     dtype=np.float32).reshape(2, -1).T  # segment xy
    return segments
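
# Illustrative example (not part of the original module): a 3-point triangle is
# closed and linearly interpolated up to 10 evenly spaced points.
# >>> tri = [np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])]
# >>> out = resample_segments(tri, n=10)
# >>> out[0].shape
# (10, 2)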

def crop_mask(masks, boxes):
    """
    It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.

    Args:
        masks (torch.Tensor): [n, h, w] tensor of masks
        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form

    Returns:
        (torch.Tensor): The masks are being cropped to the bounding box.
    """
    n, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # box coords, each shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # x coordinates, shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # y coordinates, shape(1,h,1)
    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
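
# Illustrative example (not part of the original module): everything outside the
# box [1, 1, 3, 3] is zeroed, leaving a 2x2 patch of ones.
# >>> m = torch.ones(1, 4, 4)
# >>> crop_mask(m, torch.tensor([[1.0, 1.0, 3.0, 3.0]])).sum()
# tensor(4.)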

def process_mask_upsample(protos, masks_in, bboxes, shape):
    """
    Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
    quality but is slower.

    Args:
        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
        shape (tuple): the size of the input image (h,w)

    Returns:
        (torch.Tensor): The upsampled masks.
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
    masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    masks = crop_mask(masks, bboxes)  # CHW
    return masks.gt_(0.5)

def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """
    Apply masks to bounding boxes using the output of the mask head.

    Args:
        protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
        masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
        bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
        shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
        upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.

    Returns:
        (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
            are the height and width of the input image. The mask is applied to the bounding boxes.
    """
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW

    downsampled_bboxes = bboxes.clone()
    downsampled_bboxes[:, 0] *= mw / iw
    downsampled_bboxes[:, 1] *= mh / ih
    downsampled_bboxes[:, 2] *= mw / iw
    downsampled_bboxes[:, 3] *= mh / ih

    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if upsample:
        masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    return masks.gt_(0.5)
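
# Illustrative example (not part of the original module): random prototypes and
# coefficients for 2 detections on an assumed 640x640 input produce 2 binary
# masks at the prototype resolution (160x160 here), since upsample=False.
# >>> protos = torch.randn(32, 160, 160)
# >>> coeffs = torch.randn(2, 32)
# >>> boxes = torch.tensor([[0.0, 0.0, 320.0, 320.0], [320.0, 320.0, 640.0, 640.0]])
# >>> process_mask(protos, coeffs, boxes, shape=(640, 640)).shape
# torch.Size([2, 160, 160])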

def process_mask_native(protos, masks_in, bboxes, shape):
    """
    It takes the output of the mask head, and crops it after upsampling to the bounding boxes.

    Args:
        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
        shape (tuple): the size of the input image (h,w)

    Returns:
        masks (torch.Tensor): The returned masks with dimensions [n, h, w]
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
    masks = scale_masks(masks[None], shape)[0]  # CHW
    masks = crop_mask(masks, bboxes)  # CHW
    return masks.gt_(0.5)

def scale_masks(masks, shape, padding=True):
    """
    Rescale segment masks to shape.

    Args:
        masks (torch.Tensor): (N, C, H, W).
        shape (tuple): Height and width.
        padding (bool): If True, assume the masks are based on an image augmented with YOLO-style letterbox padding.
            If False, do regular rescaling.
    """
    mh, mw = masks.shape[2:]
    gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
    pad = [mw - shape[1] * gain, mh - shape[0] * gain]  # wh padding
    if padding:
        pad[0] /= 2
        pad[1] /= 2
    top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0)  # y, x
    bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
    masks = masks[..., top:bottom, left:right]

    masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False)  # NCHW
    return masks

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
    """
    Rescale segment coordinates (xy) from img1_shape to img0_shape.

    Args:
        img1_shape (tuple): The shape of the image that the coords are from.
        coords (torch.Tensor): the coords to be scaled of shape n,2.
        img0_shape (tuple): the shape of the image that the segmentation is being applied to.
        ratio_pad (tuple): the ratio of the image size to the padded image size.
        normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False.
        padding (bool): If True, assume the coords are based on an image augmented with YOLO-style letterbox padding.
            If False, do regular rescaling.

    Returns:
        coords (torch.Tensor): The scaled coordinates.
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        coords[..., 0] -= pad[0]  # x padding
        coords[..., 1] -= pad[1]  # y padding
    coords[..., 0] /= gain
    coords[..., 1] /= gain
    clip_coords(coords, img0_shape)
    if normalize:
        coords[..., 0] /= img0_shape[1]  # width
        coords[..., 1] /= img0_shape[0]  # height
    return coords

def masks2segments(masks, strategy='largest'):
    """
    It takes a list of masks(n,h,w) and returns a list of segments(n,xy).

    Args:
        masks (torch.Tensor): the output of the model, a tensor of shape (n, h, w), e.g. (batch_size, 160, 160)
        strategy (str): 'concat' or 'largest'. Defaults to largest

    Returns:
        segments (List): list of segment masks, each an (m, 2) float32 array of contour points
    """
    segments = []
    for x in masks.int().cpu().numpy().astype('uint8'):
        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if c:
            if strategy == 'concat':  # concatenate all segments
                c = np.concatenate([x.reshape(-1, 2) for x in c])
            elif strategy == 'largest':  # select largest segment
                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
        else:
            c = np.zeros((0, 2))  # no segments found
        segments.append(c.astype('float32'))
    return segments

def clean_str(s):
    """
    Cleans a string by replacing special characters with an underscore '_'.

    Args:
        s (str): a string needing special characters replaced

    Returns:
        (str): a string with special characters replaced by an underscore '_'
    """
    return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
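
# Illustrative example (not part of the original module): special characters are
# replaced with underscores, e.g. when building safe stream or window names.
# >>> clean_str('rtsp://user@host:554/stream?x=1')
# 'rtsp_//user_host_554/stream_x_1'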