# torchvision/transforms/v2/_utils.py — internal helpers for the v2 transforms API.
  1. from __future__ import annotations
  2. import collections.abc
  3. import numbers
  4. from contextlib import suppress
  5. from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
  6. import PIL.Image
  7. import torch
  8. from torchvision import tv_tensors
  9. from torchvision._utils import sequence_to_str
  10. from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
  11. from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
  12. from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
  13. def _setup_number_or_seq(arg: Union[int, float, Sequence[Union[int, float]]], name: str) -> Sequence[float]:
  14. if not isinstance(arg, (int, float, Sequence)):
  15. raise TypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}")
  16. if isinstance(arg, Sequence) and len(arg) not in (1, 2):
  17. raise ValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}")
  18. if isinstance(arg, Sequence):
  19. for element in arg:
  20. if not isinstance(element, (int, float)):
  21. raise ValueError(f"{name} should be a sequence of numbers. Got {type(element)}")
  22. if isinstance(arg, (int, float)):
  23. arg = [float(arg), float(arg)]
  24. elif isinstance(arg, Sequence):
  25. if len(arg) == 1:
  26. arg = [float(arg[0]), float(arg[0])]
  27. else:
  28. arg = [float(arg[0]), float(arg[1])]
  29. return arg
  30. def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None:
  31. if isinstance(fill, dict):
  32. for value in fill.values():
  33. _check_fill_arg(value)
  34. else:
  35. if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
  36. raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
  37. def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT:
  38. # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
  39. # So, we can't reassign fill to 0
  40. # if fill is None:
  41. # fill = 0
  42. if fill is None:
  43. return fill
  44. if not isinstance(fill, (int, float)):
  45. fill = [float(v) for v in list(fill)]
  46. return fill # type: ignore[return-value]
  47. def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]:
  48. _check_fill_arg(fill)
  49. if isinstance(fill, dict):
  50. for k, v in fill.items():
  51. fill[k] = _convert_fill_arg(v)
  52. return fill # type: ignore[return-value]
  53. else:
  54. return {"others": _convert_fill_arg(fill)}
  55. def _get_fill(fill_dict, inpt_type):
  56. if inpt_type in fill_dict:
  57. return fill_dict[inpt_type]
  58. elif "others" in fill_dict:
  59. return fill_dict["others"]
  60. else:
  61. RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.")
  62. def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
  63. if not isinstance(padding, (numbers.Number, tuple, list)):
  64. raise TypeError("Got inappropriate padding arg")
  65. if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
  66. raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
  67. # TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
  68. # https://github.com/pytorch/vision/issues/6250
  69. def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
  70. if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
  71. raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
  72. def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
  73. """
  74. This heuristic covers three cases:
  75. 1. The input is tuple or list whose second item is a labels tensor. This happens for already batched
  76. classification inputs for MixUp and CutMix (typically after the Dataloder).
  77. 2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
  78. under a label-like (see below) key. This happens for the inputs of detection models.
  79. 3. The input is a dictionary that is structured as the one from 2.
  80. What is "label-like" key? We first search for an case-insensitive match of 'labels' inside the keys of the
  81. dictionary. This is the name our detection models expect. If we can't find that, we look for a case-insensitive
  82. match of the term 'label' anywhere inside the key, i.e. 'FooLaBeLBar'. If we can't find that either, the dictionary
  83. contains no "label-like" key.
  84. """
  85. if isinstance(inputs, (tuple, list)):
  86. inputs = inputs[1]
  87. # MixUp, CutMix
  88. if is_pure_tensor(inputs):
  89. return inputs
  90. if not isinstance(inputs, collections.abc.Mapping):
  91. raise ValueError(
  92. f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple "
  93. f"whose second item is a dictionary or a tensor, but got {inputs} instead."
  94. )
  95. candidate_key = None
  96. with suppress(StopIteration):
  97. candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
  98. if candidate_key is None:
  99. with suppress(StopIteration):
  100. candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
  101. if candidate_key is None:
  102. raise ValueError(
  103. "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
  104. "If there are no labels in the sample by design, pass labels_getter=None."
  105. )
  106. return inputs[candidate_key]
  107. def _parse_labels_getter(
  108. labels_getter: Union[str, Callable[[Any], Optional[torch.Tensor]], None]
  109. ) -> Callable[[Any], Optional[torch.Tensor]]:
  110. if labels_getter == "default":
  111. return _find_labels_default_heuristic
  112. elif callable(labels_getter):
  113. return labels_getter
  114. elif labels_getter is None:
  115. return lambda _: None
  116. else:
  117. raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
  118. def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
  119. # This assumes there is only one bbox per sample as per the general convention
  120. try:
  121. return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes))
  122. except StopIteration:
  123. raise ValueError("No bounding boxes were found in the sample")
  124. def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
  125. chws = {
  126. tuple(get_dimensions(inpt))
  127. for inpt in flat_inputs
  128. if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
  129. }
  130. if not chws:
  131. raise TypeError("No image or video was found in the sample")
  132. elif len(chws) > 1:
  133. raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
  134. c, h, w = chws.pop()
  135. return c, h, w
  136. def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
  137. sizes = {
  138. tuple(get_size(inpt))
  139. for inpt in flat_inputs
  140. if check_type(
  141. inpt,
  142. (
  143. is_pure_tensor,
  144. tv_tensors.Image,
  145. PIL.Image.Image,
  146. tv_tensors.Video,
  147. tv_tensors.Mask,
  148. tv_tensors.BoundingBoxes,
  149. ),
  150. )
  151. }
  152. if not sizes:
  153. raise TypeError("No image, video, mask or bounding box was found in the sample")
  154. elif len(sizes) > 1:
  155. raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
  156. h, w = sizes.pop()
  157. return h, w
  158. def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
  159. for type_or_check in types_or_checks:
  160. if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
  161. return True
  162. return False
  163. def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
  164. for inpt in flat_inputs:
  165. if check_type(inpt, types_or_checks):
  166. return True
  167. return False
  168. def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
  169. for type_or_check in types_or_checks:
  170. for inpt in flat_inputs:
  171. if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
  172. break
  173. else:
  174. return False
  175. return True