_meta.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. from typing import List, Optional, Tuple
  2. import PIL.Image
  3. import torch
  4. from torchvision import tv_tensors
  5. from torchvision.transforms import _functional_pil as _FP
  6. from torchvision.tv_tensors import BoundingBoxFormat
  7. from torchvision.utils import _log_api_usage_once
  8. from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
  9. def get_dimensions(inpt: torch.Tensor) -> List[int]:
  10. if torch.jit.is_scripting():
  11. return get_dimensions_image(inpt)
  12. _log_api_usage_once(get_dimensions)
  13. kernel = _get_kernel(get_dimensions, type(inpt))
  14. return kernel(inpt)
  15. @_register_kernel_internal(get_dimensions, torch.Tensor)
  16. @_register_kernel_internal(get_dimensions, tv_tensors.Image, tv_tensor_wrapper=False)
  17. def get_dimensions_image(image: torch.Tensor) -> List[int]:
  18. chw = list(image.shape[-3:])
  19. ndims = len(chw)
  20. if ndims == 3:
  21. return chw
  22. elif ndims == 2:
  23. chw.insert(0, 1)
  24. return chw
  25. else:
  26. raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
  27. _get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
  28. @_register_kernel_internal(get_dimensions, tv_tensors.Video, tv_tensor_wrapper=False)
  29. def get_dimensions_video(video: torch.Tensor) -> List[int]:
  30. return get_dimensions_image(video)
  31. def get_num_channels(inpt: torch.Tensor) -> int:
  32. if torch.jit.is_scripting():
  33. return get_num_channels_image(inpt)
  34. _log_api_usage_once(get_num_channels)
  35. kernel = _get_kernel(get_num_channels, type(inpt))
  36. return kernel(inpt)
  37. @_register_kernel_internal(get_num_channels, torch.Tensor)
  38. @_register_kernel_internal(get_num_channels, tv_tensors.Image, tv_tensor_wrapper=False)
  39. def get_num_channels_image(image: torch.Tensor) -> int:
  40. chw = image.shape[-3:]
  41. ndims = len(chw)
  42. if ndims == 3:
  43. return chw[0]
  44. elif ndims == 2:
  45. return 1
  46. else:
  47. raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
  48. _get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
  49. @_register_kernel_internal(get_num_channels, tv_tensors.Video, tv_tensor_wrapper=False)
  50. def get_num_channels_video(video: torch.Tensor) -> int:
  51. return get_num_channels_image(video)
  52. # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
  53. # deprecating the old names.
  54. get_image_num_channels = get_num_channels
  55. def get_size(inpt: torch.Tensor) -> List[int]:
  56. if torch.jit.is_scripting():
  57. return get_size_image(inpt)
  58. _log_api_usage_once(get_size)
  59. kernel = _get_kernel(get_size, type(inpt))
  60. return kernel(inpt)
  61. @_register_kernel_internal(get_size, torch.Tensor)
  62. @_register_kernel_internal(get_size, tv_tensors.Image, tv_tensor_wrapper=False)
  63. def get_size_image(image: torch.Tensor) -> List[int]:
  64. hw = list(image.shape[-2:])
  65. ndims = len(hw)
  66. if ndims == 2:
  67. return hw
  68. else:
  69. raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
  70. @_register_kernel_internal(get_size, PIL.Image.Image)
  71. def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
  72. width, height = _FP.get_image_size(image)
  73. return [height, width]
  74. @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
  75. def get_size_video(video: torch.Tensor) -> List[int]:
  76. return get_size_image(video)
  77. @_register_kernel_internal(get_size, tv_tensors.Mask, tv_tensor_wrapper=False)
  78. def get_size_mask(mask: torch.Tensor) -> List[int]:
  79. return get_size_image(mask)
  80. @_register_kernel_internal(get_size, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
  81. def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> List[int]:
  82. return list(bounding_box.canvas_size)
  83. def get_num_frames(inpt: torch.Tensor) -> int:
  84. if torch.jit.is_scripting():
  85. return get_num_frames_video(inpt)
  86. _log_api_usage_once(get_num_frames)
  87. kernel = _get_kernel(get_num_frames, type(inpt))
  88. return kernel(inpt)
  89. @_register_kernel_internal(get_num_frames, torch.Tensor)
  90. @_register_kernel_internal(get_num_frames, tv_tensors.Video, tv_tensor_wrapper=False)
  91. def get_num_frames_video(video: torch.Tensor) -> int:
  92. return video.shape[-4]
  93. def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
  94. xyxy = xywh if inplace else xywh.clone()
  95. xyxy[..., 2:] += xyxy[..., :2]
  96. return xyxy
  97. def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
  98. xywh = xyxy if inplace else xyxy.clone()
  99. xywh[..., 2:] -= xywh[..., :2]
  100. return xywh
  101. def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
  102. if not inplace:
  103. cxcywh = cxcywh.clone()
  104. # Trick to do fast division by 2 and ceil, without casting. It produces the same result as
  105. # `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
  106. half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
  107. # (cx - width / 2) = x1, same for y1
  108. cxcywh[..., :2].sub_(half_wh)
  109. # (x1 + width) = x2, same for y2
  110. cxcywh[..., 2:].add_(cxcywh[..., :2])
  111. return cxcywh
  112. def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
  113. if not inplace:
  114. xyxy = xyxy.clone()
  115. # (x2 - x1) = width, same for height
  116. xyxy[..., 2:].sub_(xyxy[..., :2])
  117. # (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy
  118. xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")
  119. return xyxy
  120. def _convert_bounding_box_format(
  121. bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
  122. ) -> torch.Tensor:
  123. if new_format == old_format:
  124. return bounding_boxes
  125. # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
  126. if old_format == BoundingBoxFormat.XYWH:
  127. bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace)
  128. elif old_format == BoundingBoxFormat.CXCYWH:
  129. bounding_boxes = _cxcywh_to_xyxy(bounding_boxes, inplace)
  130. if new_format == BoundingBoxFormat.XYWH:
  131. bounding_boxes = _xyxy_to_xywh(bounding_boxes, inplace)
  132. elif new_format == BoundingBoxFormat.CXCYWH:
  133. bounding_boxes = _xyxy_to_cxcywh(bounding_boxes, inplace)
  134. return bounding_boxes
  135. def convert_bounding_box_format(
  136. inpt: torch.Tensor,
  137. old_format: Optional[BoundingBoxFormat] = None,
  138. new_format: Optional[BoundingBoxFormat] = None,
  139. inplace: bool = False,
  140. ) -> torch.Tensor:
  141. """[BETA] See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details."""
  142. # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
  143. # inputs as well as extract it from `tv_tensors.BoundingBoxes` inputs. However, putting a default value on
  144. # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
  145. # default error that would be thrown if `new_format` had no default value.
  146. if new_format is None:
  147. raise TypeError("convert_bounding_box_format() missing 1 required argument: 'new_format'")
  148. if not torch.jit.is_scripting():
  149. _log_api_usage_once(convert_bounding_box_format)
  150. if torch.jit.is_scripting() or is_pure_tensor(inpt):
  151. if old_format is None:
  152. raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
  153. return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
  154. elif isinstance(inpt, tv_tensors.BoundingBoxes):
  155. if old_format is not None:
  156. raise ValueError("For bounding box tv_tensor inputs, `old_format` must not be passed.")
  157. output = _convert_bounding_box_format(
  158. inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
  159. )
  160. return tv_tensors.wrap(output, like=inpt, format=new_format)
  161. else:
  162. raise TypeError(
  163. f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
  164. )
  165. def _clamp_bounding_boxes(
  166. bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: Tuple[int, int]
  167. ) -> torch.Tensor:
  168. # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
  169. # BoundingBoxFormat instead of converting back and forth
  170. in_dtype = bounding_boxes.dtype
  171. bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
  172. xyxy_boxes = convert_bounding_box_format(
  173. bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
  174. )
  175. xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
  176. xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
  177. out_boxes = convert_bounding_box_format(
  178. xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
  179. )
  180. return out_boxes.to(in_dtype)
  181. def clamp_bounding_boxes(
  182. inpt: torch.Tensor,
  183. format: Optional[BoundingBoxFormat] = None,
  184. canvas_size: Optional[Tuple[int, int]] = None,
  185. ) -> torch.Tensor:
  186. """[BETA] See :func:`~torchvision.transforms.v2.ClampBoundingBoxes` for details."""
  187. if not torch.jit.is_scripting():
  188. _log_api_usage_once(clamp_bounding_boxes)
  189. if torch.jit.is_scripting() or is_pure_tensor(inpt):
  190. if format is None or canvas_size is None:
  191. raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
  192. return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
  193. elif isinstance(inpt, tv_tensors.BoundingBoxes):
  194. if format is not None or canvas_size is not None:
  195. raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
  196. output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
  197. return tv_tensors.wrap(output, like=inpt)
  198. else:
  199. raise TypeError(
  200. f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
  201. )