autoaugment.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. import math
  2. from enum import Enum
  3. from typing import Dict, List, Optional, Tuple
  4. import torch
  5. from torch import Tensor
  6. from . import functional as F, InterpolationMode
  7. __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
  8. def _apply_op(
  9. img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
  10. ):
  11. if op_name == "ShearX":
  12. # magnitude should be arctan(magnitude)
  13. # official autoaug: (1, level, 0, 0, 1, 0)
  14. # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
  15. # compared to
  16. # torchvision: (1, tan(level), 0, 0, 1, 0)
  17. # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
  18. img = F.affine(
  19. img,
  20. angle=0.0,
  21. translate=[0, 0],
  22. scale=1.0,
  23. shear=[math.degrees(math.atan(magnitude)), 0.0],
  24. interpolation=interpolation,
  25. fill=fill,
  26. center=[0, 0],
  27. )
  28. elif op_name == "ShearY":
  29. # magnitude should be arctan(magnitude)
  30. # See above
  31. img = F.affine(
  32. img,
  33. angle=0.0,
  34. translate=[0, 0],
  35. scale=1.0,
  36. shear=[0.0, math.degrees(math.atan(magnitude))],
  37. interpolation=interpolation,
  38. fill=fill,
  39. center=[0, 0],
  40. )
  41. elif op_name == "TranslateX":
  42. img = F.affine(
  43. img,
  44. angle=0.0,
  45. translate=[int(magnitude), 0],
  46. scale=1.0,
  47. interpolation=interpolation,
  48. shear=[0.0, 0.0],
  49. fill=fill,
  50. )
  51. elif op_name == "TranslateY":
  52. img = F.affine(
  53. img,
  54. angle=0.0,
  55. translate=[0, int(magnitude)],
  56. scale=1.0,
  57. interpolation=interpolation,
  58. shear=[0.0, 0.0],
  59. fill=fill,
  60. )
  61. elif op_name == "Rotate":
  62. img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
  63. elif op_name == "Brightness":
  64. img = F.adjust_brightness(img, 1.0 + magnitude)
  65. elif op_name == "Color":
  66. img = F.adjust_saturation(img, 1.0 + magnitude)
  67. elif op_name == "Contrast":
  68. img = F.adjust_contrast(img, 1.0 + magnitude)
  69. elif op_name == "Sharpness":
  70. img = F.adjust_sharpness(img, 1.0 + magnitude)
  71. elif op_name == "Posterize":
  72. img = F.posterize(img, int(magnitude))
  73. elif op_name == "Solarize":
  74. img = F.solarize(img, magnitude)
  75. elif op_name == "AutoContrast":
  76. img = F.autocontrast(img)
  77. elif op_name == "Equalize":
  78. img = F.equalize(img)
  79. elif op_name == "Invert":
  80. img = F.invert(img)
  81. elif op_name == "Identity":
  82. pass
  83. else:
  84. raise ValueError(f"The provided operator {op_name} is not recognized.")
  85. return img
  86. class AutoAugmentPolicy(Enum):
  87. """AutoAugment policies learned on different datasets.
  88. Available policies are IMAGENET, CIFAR10 and SVHN.
  89. """
  90. IMAGENET = "imagenet"
  91. CIFAR10 = "cifar10"
  92. SVHN = "svhn"
  93. # FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
  94. class AutoAugment(torch.nn.Module):
  95. r"""AutoAugment data augmentation method based on
  96. `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
  97. If the image is torch Tensor, it should be of type torch.uint8, and it is expected
  98. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
  99. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  100. Args:
  101. policy (AutoAugmentPolicy): Desired policy enum defined by
  102. :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
  103. interpolation (InterpolationMode): Desired interpolation enum defined by
  104. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  105. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  106. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  107. image. If given a number, the value is used for all bands respectively.
  108. """
  109. def __init__(
  110. self,
  111. policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
  112. interpolation: InterpolationMode = InterpolationMode.NEAREST,
  113. fill: Optional[List[float]] = None,
  114. ) -> None:
  115. super().__init__()
  116. self.policy = policy
  117. self.interpolation = interpolation
  118. self.fill = fill
  119. self.policies = self._get_policies(policy)
  120. def _get_policies(
  121. self, policy: AutoAugmentPolicy
  122. ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
  123. if policy == AutoAugmentPolicy.IMAGENET:
  124. return [
  125. (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
  126. (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
  127. (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
  128. (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
  129. (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
  130. (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
  131. (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
  132. (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
  133. (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
  134. (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
  135. (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
  136. (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
  137. (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
  138. (("Invert", 0.6, None), ("Equalize", 1.0, None)),
  139. (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
  140. (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
  141. (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
  142. (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
  143. (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
  144. (("Color", 0.4, 0), ("Equalize", 0.6, None)),
  145. (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
  146. (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
  147. (("Invert", 0.6, None), ("Equalize", 1.0, None)),
  148. (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
  149. (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
  150. ]
  151. elif policy == AutoAugmentPolicy.CIFAR10:
  152. return [
  153. (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
  154. (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
  155. (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
  156. (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
  157. (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
  158. (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
  159. (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
  160. (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
  161. (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
  162. (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
  163. (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
  164. (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
  165. (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
  166. (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
  167. (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
  168. (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
  169. (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
  170. (("Color", 0.9, 9), ("Equalize", 0.6, None)),
  171. (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
  172. (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
  173. (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
  174. (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
  175. (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
  176. (("Equalize", 0.8, None), ("Invert", 0.1, None)),
  177. (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
  178. ]
  179. elif policy == AutoAugmentPolicy.SVHN:
  180. return [
  181. (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
  182. (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
  183. (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
  184. (("Invert", 0.9, None), ("Equalize", 0.6, None)),
  185. (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
  186. (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
  187. (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
  188. (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
  189. (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
  190. (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
  191. (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
  192. (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
  193. (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
  194. (("Invert", 0.9, None), ("Equalize", 0.6, None)),
  195. (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
  196. (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
  197. (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
  198. (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
  199. (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
  200. (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
  201. (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
  202. (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
  203. (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
  204. (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
  205. (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
  206. ]
  207. else:
  208. raise ValueError(f"The provided policy {policy} is not recognized.")
  209. def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
  210. return {
  211. # op_name: (magnitudes, signed)
  212. "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
  213. "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
  214. "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
  215. "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
  216. "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
  217. "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
  218. "Color": (torch.linspace(0.0, 0.9, num_bins), True),
  219. "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
  220. "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
  221. "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
  222. "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
  223. "AutoContrast": (torch.tensor(0.0), False),
  224. "Equalize": (torch.tensor(0.0), False),
  225. "Invert": (torch.tensor(0.0), False),
  226. }
  227. @staticmethod
  228. def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
  229. """Get parameters for autoaugment transformation
  230. Returns:
  231. params required by the autoaugment transformation
  232. """
  233. policy_id = int(torch.randint(transform_num, (1,)).item())
  234. probs = torch.rand((2,))
  235. signs = torch.randint(2, (2,))
  236. return policy_id, probs, signs
  237. def forward(self, img: Tensor) -> Tensor:
  238. """
  239. img (PIL Image or Tensor): Image to be transformed.
  240. Returns:
  241. PIL Image or Tensor: AutoAugmented image.
  242. """
  243. fill = self.fill
  244. channels, height, width = F.get_dimensions(img)
  245. if isinstance(img, Tensor):
  246. if isinstance(fill, (int, float)):
  247. fill = [float(fill)] * channels
  248. elif fill is not None:
  249. fill = [float(f) for f in fill]
  250. transform_id, probs, signs = self.get_params(len(self.policies))
  251. op_meta = self._augmentation_space(10, (height, width))
  252. for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
  253. if probs[i] <= p:
  254. magnitudes, signed = op_meta[op_name]
  255. magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
  256. if signed and signs[i] == 0:
  257. magnitude *= -1.0
  258. img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
  259. return img
  260. def __repr__(self) -> str:
  261. return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})"
  262. class RandAugment(torch.nn.Module):
  263. r"""RandAugment data augmentation method based on
  264. `"RandAugment: Practical automated data augmentation with a reduced search space"
  265. <https://arxiv.org/abs/1909.13719>`_.
  266. If the image is torch Tensor, it should be of type torch.uint8, and it is expected
  267. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
  268. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  269. Args:
  270. num_ops (int): Number of augmentation transformations to apply sequentially.
  271. magnitude (int): Magnitude for all the transformations.
  272. num_magnitude_bins (int): The number of different magnitude values.
  273. interpolation (InterpolationMode): Desired interpolation enum defined by
  274. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  275. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  276. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  277. image. If given a number, the value is used for all bands respectively.
  278. """
  279. def __init__(
  280. self,
  281. num_ops: int = 2,
  282. magnitude: int = 9,
  283. num_magnitude_bins: int = 31,
  284. interpolation: InterpolationMode = InterpolationMode.NEAREST,
  285. fill: Optional[List[float]] = None,
  286. ) -> None:
  287. super().__init__()
  288. self.num_ops = num_ops
  289. self.magnitude = magnitude
  290. self.num_magnitude_bins = num_magnitude_bins
  291. self.interpolation = interpolation
  292. self.fill = fill
  293. def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
  294. return {
  295. # op_name: (magnitudes, signed)
  296. "Identity": (torch.tensor(0.0), False),
  297. "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
  298. "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
  299. "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
  300. "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
  301. "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
  302. "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
  303. "Color": (torch.linspace(0.0, 0.9, num_bins), True),
  304. "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
  305. "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
  306. "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
  307. "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
  308. "AutoContrast": (torch.tensor(0.0), False),
  309. "Equalize": (torch.tensor(0.0), False),
  310. }
  311. def forward(self, img: Tensor) -> Tensor:
  312. """
  313. img (PIL Image or Tensor): Image to be transformed.
  314. Returns:
  315. PIL Image or Tensor: Transformed image.
  316. """
  317. fill = self.fill
  318. channels, height, width = F.get_dimensions(img)
  319. if isinstance(img, Tensor):
  320. if isinstance(fill, (int, float)):
  321. fill = [float(fill)] * channels
  322. elif fill is not None:
  323. fill = [float(f) for f in fill]
  324. op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
  325. for _ in range(self.num_ops):
  326. op_index = int(torch.randint(len(op_meta), (1,)).item())
  327. op_name = list(op_meta.keys())[op_index]
  328. magnitudes, signed = op_meta[op_name]
  329. magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
  330. if signed and torch.randint(2, (1,)):
  331. magnitude *= -1.0
  332. img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
  333. return img
  334. def __repr__(self) -> str:
  335. s = (
  336. f"{self.__class__.__name__}("
  337. f"num_ops={self.num_ops}"
  338. f", magnitude={self.magnitude}"
  339. f", num_magnitude_bins={self.num_magnitude_bins}"
  340. f", interpolation={self.interpolation}"
  341. f", fill={self.fill}"
  342. f")"
  343. )
  344. return s
  345. class TrivialAugmentWide(torch.nn.Module):
  346. r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
  347. `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
  348. If the image is torch Tensor, it should be of type torch.uint8, and it is expected
  349. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
  350. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  351. Args:
  352. num_magnitude_bins (int): The number of different magnitude values.
  353. interpolation (InterpolationMode): Desired interpolation enum defined by
  354. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  355. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  356. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  357. image. If given a number, the value is used for all bands respectively.
  358. """
  359. def __init__(
  360. self,
  361. num_magnitude_bins: int = 31,
  362. interpolation: InterpolationMode = InterpolationMode.NEAREST,
  363. fill: Optional[List[float]] = None,
  364. ) -> None:
  365. super().__init__()
  366. self.num_magnitude_bins = num_magnitude_bins
  367. self.interpolation = interpolation
  368. self.fill = fill
  369. def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
  370. return {
  371. # op_name: (magnitudes, signed)
  372. "Identity": (torch.tensor(0.0), False),
  373. "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
  374. "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
  375. "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
  376. "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
  377. "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
  378. "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
  379. "Color": (torch.linspace(0.0, 0.99, num_bins), True),
  380. "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
  381. "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
  382. "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
  383. "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
  384. "AutoContrast": (torch.tensor(0.0), False),
  385. "Equalize": (torch.tensor(0.0), False),
  386. }
  387. def forward(self, img: Tensor) -> Tensor:
  388. """
  389. img (PIL Image or Tensor): Image to be transformed.
  390. Returns:
  391. PIL Image or Tensor: Transformed image.
  392. """
  393. fill = self.fill
  394. channels, height, width = F.get_dimensions(img)
  395. if isinstance(img, Tensor):
  396. if isinstance(fill, (int, float)):
  397. fill = [float(fill)] * channels
  398. elif fill is not None:
  399. fill = [float(f) for f in fill]
  400. op_meta = self._augmentation_space(self.num_magnitude_bins)
  401. op_index = int(torch.randint(len(op_meta), (1,)).item())
  402. op_name = list(op_meta.keys())[op_index]
  403. magnitudes, signed = op_meta[op_name]
  404. magnitude = (
  405. float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item())
  406. if magnitudes.ndim > 0
  407. else 0.0
  408. )
  409. if signed and torch.randint(2, (1,)):
  410. magnitude *= -1.0
  411. return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
  412. def __repr__(self) -> str:
  413. s = (
  414. f"{self.__class__.__name__}("
  415. f"num_magnitude_bins={self.num_magnitude_bins}"
  416. f", interpolation={self.interpolation}"
  417. f", fill={self.fill}"
  418. f")"
  419. )
  420. return s
  421. class AugMix(torch.nn.Module):
  422. r"""AugMix data augmentation method based on
  423. `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
  424. If the image is torch Tensor, it should be of type torch.uint8, and it is expected
  425. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
  426. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  427. Args:
  428. severity (int): The severity of base augmentation operators. Default is ``3``.
  429. mixture_width (int): The number of augmentation chains. Default is ``3``.
  430. chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
  431. Default is ``-1``.
  432. alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
  433. all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
  434. interpolation (InterpolationMode): Desired interpolation enum defined by
  435. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  436. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  437. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  438. image. If given a number, the value is used for all bands respectively.
  439. """
  440. def __init__(
  441. self,
  442. severity: int = 3,
  443. mixture_width: int = 3,
  444. chain_depth: int = -1,
  445. alpha: float = 1.0,
  446. all_ops: bool = True,
  447. interpolation: InterpolationMode = InterpolationMode.BILINEAR,
  448. fill: Optional[List[float]] = None,
  449. ) -> None:
  450. super().__init__()
  451. self._PARAMETER_MAX = 10
  452. if not (1 <= severity <= self._PARAMETER_MAX):
  453. raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
  454. self.severity = severity
  455. self.mixture_width = mixture_width
  456. self.chain_depth = chain_depth
  457. self.alpha = alpha
  458. self.all_ops = all_ops
  459. self.interpolation = interpolation
  460. self.fill = fill
  461. def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
  462. s = {
  463. # op_name: (magnitudes, signed)
  464. "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
  465. "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
  466. "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
  467. "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
  468. "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
  469. "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
  470. "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
  471. "AutoContrast": (torch.tensor(0.0), False),
  472. "Equalize": (torch.tensor(0.0), False),
  473. }
  474. if self.all_ops:
  475. s.update(
  476. {
  477. "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
  478. "Color": (torch.linspace(0.0, 0.9, num_bins), True),
  479. "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
  480. "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
  481. }
  482. )
  483. return s
  484. @torch.jit.unused
  485. def _pil_to_tensor(self, img) -> Tensor:
  486. return F.pil_to_tensor(img)
  487. @torch.jit.unused
  488. def _tensor_to_pil(self, img: Tensor):
  489. return F.to_pil_image(img)
  490. def _sample_dirichlet(self, params: Tensor) -> Tensor:
  491. # Must be on a separate method so that we can overwrite it in tests.
  492. return torch._sample_dirichlet(params)
  493. def forward(self, orig_img: Tensor) -> Tensor:
  494. """
  495. img (PIL Image or Tensor): Image to be transformed.
  496. Returns:
  497. PIL Image or Tensor: Transformed image.
  498. """
  499. fill = self.fill
  500. channels, height, width = F.get_dimensions(orig_img)
  501. if isinstance(orig_img, Tensor):
  502. img = orig_img
  503. if isinstance(fill, (int, float)):
  504. fill = [float(fill)] * channels
  505. elif fill is not None:
  506. fill = [float(f) for f in fill]
  507. else:
  508. img = self._pil_to_tensor(orig_img)
  509. op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))
  510. orig_dims = list(img.shape)
  511. batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
  512. batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
  513. # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
  514. # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
  515. m = self._sample_dirichlet(
  516. torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
  517. )
  518. # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
  519. combined_weights = self._sample_dirichlet(
  520. torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
  521. ) * m[:, 1].view([batch_dims[0], -1])
  522. mix = m[:, 0].view(batch_dims) * batch
  523. for i in range(self.mixture_width):
  524. aug = batch
  525. depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
  526. for _ in range(depth):
  527. op_index = int(torch.randint(len(op_meta), (1,)).item())
  528. op_name = list(op_meta.keys())[op_index]
  529. magnitudes, signed = op_meta[op_name]
  530. magnitude = (
  531. float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
  532. if magnitudes.ndim > 0
  533. else 0.0
  534. )
  535. if signed and torch.randint(2, (1,)):
  536. magnitude *= -1.0
  537. aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
  538. mix.add_(combined_weights[:, i].view(batch_dims) * aug)
  539. mix = mix.view(orig_dims).to(dtype=img.dtype)
  540. if not isinstance(orig_img, Tensor):
  541. return self._tensor_to_pil(mix)
  542. return mix
  543. def __repr__(self) -> str:
  544. s = (
  545. f"{self.__class__.__name__}("
  546. f"severity={self.severity}"
  547. f", mixture_width={self.mixture_width}"
  548. f", chain_depth={self.chain_depth}"
  549. f", alpha={self.alpha}"
  550. f", all_ops={self.all_ops}"
  551. f", interpolation={self.interpolation}"
  552. f", fill={self.fill}"
  553. f")"
  554. )
  555. return s