_bounding_boxes.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from __future__ import annotations
  2. from enum import Enum
  3. from typing import Any, Mapping, Optional, Sequence, Tuple, Union
  4. import torch
  5. from torch.utils._pytree import tree_flatten
  6. from ._tv_tensor import TVTensor
  7. class BoundingBoxFormat(Enum):
  8. """[BETA] Coordinate format of a bounding box.
  9. Available formats are
  10. * ``XYXY``
  11. * ``XYWH``
  12. * ``CXCYWH``
  13. """
  14. XYXY = "XYXY"
  15. XYWH = "XYWH"
  16. CXCYWH = "CXCYWH"
  17. class BoundingBoxes(TVTensor):
  18. """[BETA] :class:`torch.Tensor` subclass for bounding boxes.
  19. .. note::
  20. There should be only one :class:`~torchvision.tv_tensors.BoundingBoxes`
  21. instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
  22. although one :class:`~torchvision.tv_tensors.BoundingBoxes` object can
  23. contain multiple bounding boxes.
  24. Args:
  25. data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
  26. format (BoundingBoxFormat, str): Format of the bounding box.
  27. canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
  28. dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
  29. ``data``.
  30. device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
  31. :class:`torch.Tensor`, the device is taken from it. Otherwise, the bounding box is constructed on the CPU.
  32. requires_grad (bool, optional): Whether autograd should record operations on the bounding box. If omitted and
  33. ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
  34. """
  35. format: BoundingBoxFormat
  36. canvas_size: Tuple[int, int]
  37. @classmethod
  38. def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes: # type: ignore[override]
  39. if check_dims:
  40. if tensor.ndim == 1:
  41. tensor = tensor.unsqueeze(0)
  42. elif tensor.ndim != 2:
  43. raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
  44. if isinstance(format, str):
  45. format = BoundingBoxFormat[format.upper()]
  46. bounding_boxes = tensor.as_subclass(cls)
  47. bounding_boxes.format = format
  48. bounding_boxes.canvas_size = canvas_size
  49. return bounding_boxes
  50. def __new__(
  51. cls,
  52. data: Any,
  53. *,
  54. format: Union[BoundingBoxFormat, str],
  55. canvas_size: Tuple[int, int],
  56. dtype: Optional[torch.dtype] = None,
  57. device: Optional[Union[torch.device, str, int]] = None,
  58. requires_grad: Optional[bool] = None,
  59. ) -> BoundingBoxes:
  60. tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
  61. return cls._wrap(tensor, format=format, canvas_size=canvas_size)
  62. @classmethod
  63. def _wrap_output(
  64. cls,
  65. output: torch.Tensor,
  66. args: Sequence[Any] = (),
  67. kwargs: Optional[Mapping[str, Any]] = None,
  68. ) -> BoundingBoxes:
  69. # If there are BoundingBoxes instances in the output, their metadata got lost when we called
  70. # super().__torch_function__. We need to restore the metadata somehow, so we choose to take
  71. # the metadata from the first bbox in the parameters.
  72. # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g.
  73. # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases.
  74. flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator]
  75. first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes))
  76. format, canvas_size = first_bbox_from_args.format, first_bbox_from_args.canvas_size
  77. if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
  78. output = BoundingBoxes._wrap(output, format=format, canvas_size=canvas_size, check_dims=False)
  79. elif isinstance(output, (tuple, list)):
  80. output = type(output)(
  81. BoundingBoxes._wrap(part, format=format, canvas_size=canvas_size, check_dims=False) for part in output
  82. )
  83. return output
  84. def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
  85. return self._make_repr(format=self.format, canvas_size=self.canvas_size)