_auto_augment.py

import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor

ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]


class _AutoAugmentBase(Transform):
    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        params = super()._extract_params_for_v1_transform()

        if isinstance(params["fill"], dict):
            raise ValueError(f"{type(self).__name__}() cannot be scripted when `fill` is a dictionary.")

        return params

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    tv_tensors.Image,
                    PIL.Image.Image,
                    is_pure_tensor,
                    tv_tensors.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)

    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        fill_ = _get_fill(fill, type(image))

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision: (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
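            # Worked example (editorial note, not part of the original comments): an autoaug
            # level of 0.3 corresponds to a shear angle of
            # math.degrees(math.atan(0.3)) ~= 16.7 degrees, which is what gets passed to
            # F.affine below.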
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")


class AutoAugment(_AutoAugmentBase):
    r"""[BETA] AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    .. v2betastatus:: AutoAugment transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """

    _v1_transform_cls = _transforms.AutoAugment

    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }
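
    # Illustrative note (editorial addition, not in the original source): for the 10 magnitude
    # bins that AutoAugment.forward() requests, the "Posterize" lambda above evaluates to bit
    # depths [8, 8, 7, 7, 6, 6, 5, 5, 4, 4], and "Solarize" sweeps the threshold fraction
    # linearly from 1.0 down to 0.0.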

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
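
# Minimal usage sketch for AutoAugment (editorial illustration, not part of the library source);
# assumes a uint8 CHW image tensor as input:
#
#   import torch
#   from torchvision.transforms import v2
#
#   img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
#   transform = v2.AutoAugment(policy=v2.AutoAugmentPolicy.IMAGENET)
#   augmented = transform(img)  # same shape and dtype as `img`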


class RandAugment(_AutoAugmentBase):
    r"""[BETA] RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    .. v2betastatus:: RandAugment transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
        magnitude (int, optional): Magnitude for all the transformations.
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """

    _v1_transform_cls = _transforms.RandAugment

    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[self.magnitude])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
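
# Minimal usage sketch for RandAugment (editorial illustration, not part of the library source):
#
#   from torchvision.transforms import v2
#
#   transform = v2.RandAugment(num_ops=2, magnitude=9)  # defaults shown explicitly
#   augmented = transform(img)  # `img`: uint8 tensor [..., 1 or 3, H, W] or PIL Image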


class TrivialAugmentWide(_AutoAugmentBase):
    r"""[BETA] Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    .. v2betastatus:: TrivialAugmentWide transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """

    _v1_transform_cls = _transforms.TrivialAugmentWide

    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
        )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
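
# Minimal usage sketch for TrivialAugmentWide (editorial illustration, not part of the library source):
#
#   from torchvision.transforms import v2
#
#   transform = v2.TrivialAugmentWide(num_magnitude_bins=31)
#   augmented = transform(img)  # one randomly chosen op at a random magnitude per call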


class AugMix(_AutoAugmentBase):
    r"""[BETA] AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    .. v2betastatus:: AugMix transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """

    _v1_transform_cls = _transforms.AugMix

    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(inpt, PIL.Image.Image):
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])
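
        # Illustrative note (editorial addition, not in the original comments): with
        # w0 = m[:, 0] and w_i = combined_weights[:, i], the loop below computes
        #   mix = w0 * batch + sum_i(w_i * aug_i)
        # where w0 + sum_i(w_i) == 1, so the result stays in the input value range.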
        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(
                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
                )
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_pil_image(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
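
# Minimal usage sketch for AugMix (editorial illustration, not part of the library source):
#
#   from torchvision.transforms import v2
#
#   transform = v2.AugMix(severity=3, mixture_width=3, chain_depth=-1, alpha=1.0)
#   mixed = transform(img)  # convex combination of `img` and its augmented chains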