_augment.py

import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Tuple

import PIL.Image
import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F

from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


class RandomErasing(_RandomApplyTransform):
    """[BETA] Randomly select a rectangle region in the input image or video and erase its pixels.

    .. v2betastatus:: RandomErasing transform

    This transform does not support PIL Images.

    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p (float, optional): probability that the random erasing operation will be performed.
        scale (tuple of float, optional): range of the proportion of the erased area against the input image.
        ratio (tuple of float, optional): range of the aspect ratio of the erased area.
        value (number or tuple of numbers): erasing value. Default is 0. If a single number, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase the R, G, B channels respectively.
            If the str ``"random"``, each pixel is erased with random values.
        inplace (bool, optional): whether to perform this transform in place. Default is False.

    Returns:
        Erased input.

    Example:
        >>> from torchvision.transforms import v2 as transforms
        >>>
        >>> transform = transforms.Compose([
        >>>     transforms.RandomHorizontalFlip(),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>     transforms.RandomErasing(),
        >>> ])
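
    A few ``value`` options, as an illustrative sketch (the per-channel tuple
    below is an assumed example, not a recommended setting):
        >>> transforms.RandomErasing(value=0.0)                     # erase with zeros (the default)
        >>> transforms.RandomErasing(value=(0.485, 0.456, 0.406))   # one value per R, G, B channel
        >>> transforms.RandomErasing(value="random")                # erase each pixel with random values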
  39. """
  40. _v1_transform_cls = _transforms.RandomErasing
  41. def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
  42. return dict(
  43. super()._extract_params_for_v1_transform(),
  44. value="random" if self.value is None else self.value,
  45. )
    def __init__(
        self,
        p: float = 0.5,
        scale: Tuple[float, float] = (0.02, 0.33),
        ratio: Tuple[float, float] = (0.3, 3.3),
        value: float = 0.0,
        inplace: bool = False,
    ):
        super().__init__(p=p)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")

        self.scale = scale
        self.ratio = ratio
        # value="random" is stored as None and resolved to random noise in _get_params
        if isinstance(value, (int, float)):
            self.value = [float(value)]
        elif isinstance(value, str):
            self.value = None
        elif isinstance(value, (list, tuple)):
            self.value = [float(v) for v in value]
        else:
            self.value = value
        self.inplace = inplace

        self._log_ratio = torch.log(torch.tensor(self.ratio))
    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)
    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        img_c, img_h, img_w = query_chw(flat_inputs)

        if self.value is not None and not (len(self.value) in (1, img_c)):
            raise ValueError(
                f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
            )

        area = img_h * img_w

        log_ratio = self._log_ratio
        # Rejection sampling: try up to 10 times to draw a rectangle that fits
        # inside the image; otherwise fall back to a no-op (v=None).
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(
                    log_ratio[0],  # type: ignore[arg-type]
                    log_ratio[1],  # type: ignore[arg-type]
                )
            ).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if self.value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(self.value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            break
        else:
            i, j, h, w, v = 0, 0, img_h, img_w, None

        return dict(i=i, j=j, h=h, w=w, v=v)
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["v"] is not None:
            inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)

        return inpt
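
# A minimal standalone sketch (shapes assumed for illustration): RandomErasing
# also operates on a single tensor image directly, without Compose.
#
#   >>> img = torch.rand(3, 224, 224)
#   >>> erased = RandomErasing(p=1.0, value="random")(img)
#   >>> erased.shape
#   torch.Size([3, 224, 224])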


class _BaseMixUpCutMix(Transform):
    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
        super().__init__()
        self.alpha = float(alpha)
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

        self.num_classes = num_classes

        self._labels_getter = _parse_labels_getter(labels_getter)
    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

        labels = self._labels_getter(inputs)
        if not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
        elif labels.ndim != 1:
            raise ValueError(
                f"labels tensor should be of shape (batch_size,) but got shape {labels.shape} instead."
            )

        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            ),
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
        # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True.
        needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)
    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
        if inpt.ndim != expected_num_dims:
            raise ValueError(
                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )
    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        label = one_hot(label, num_classes=self.num_classes)
        if not label.dtype.is_floating_point:
            label = label.float()
        # Convex combination of each one-hot label with that of the previous
        # sample in the batch: lam * label + (1 - lam) * label.roll(1, 0)
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
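        # A worked sketch of the line above (values assumed for illustration):
        # with num_classes=3 and labels tensor([0, 2]), one_hot gives
        # [[1, 0, 0], [0, 0, 1]]; rolling by one pairs each sample with its
        # predecessor, so lam=0.7 yields [[0.7, 0.0, 0.3], [0.3, 0.0, 0.7]].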


class MixUp(_BaseMixUpCutMix):
    """[BETA] Apply MixUp to the provided batch of images and labels.

    .. v2betastatus:: MixUp transform

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
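
    Example (an illustrative sketch; the batch shapes, ``num_classes``, and
    variable names below are assumptions, not part of the original docstring):
        >>> import torch
        >>> from torchvision.transforms import v2
        >>> mixup = v2.MixUp(num_classes=10)
        >>> imgs = torch.rand(8, 3, 224, 224)     # batched images (B, C, H, W)
        >>> labels = torch.randint(0, 10, (8,))   # integer labels of shape (B,)
        >>> imgs, labels = mixup(imgs, labels)
        >>> labels.shape                          # soft labels after mixing
        torch.Size([8, 10])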
  192. """
  193. def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
  194. return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]
  195. def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
  196. lam = params["lam"]
  197. if inpt is params["labels"]:
  198. return self._mixup_label(inpt, lam=lam)
  199. elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
  200. self._check_image_or_video(inpt, batch_size=params["batch_size"])
  201. output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
  202. if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
  203. output = tv_tensors.wrap(output, like=inpt)
  204. return output
  205. else:
  206. return inpt


class CutMix(_BaseMixUpCutMix):
    """[BETA] Apply CutMix to the provided batch of images and labels.

    .. v2betastatus:: CutMix transform

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int): number of classes in the batch. Used for one-hot-encoding.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
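
    Example (an illustrative sketch; the batch shapes, ``num_classes``, and
    variable names below are assumptions, not part of the original docstring):
        >>> import torch
        >>> from torchvision.transforms import v2
        >>> cutmix = v2.CutMix(num_classes=10)
        >>> imgs = torch.rand(8, 3, 224, 224)     # batched images (B, C, H, W)
        >>> labels = torch.randint(0, 10, (8,))   # integer labels of shape (B,)
        >>> imgs, labels = cutmix(imgs, labels)
        >>> labels.shape                          # soft labels after mixing
        torch.Size([8, 10])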
  229. """
  230. def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
  231. lam = float(self._dist.sample(())) # type: ignore[arg-type]
  232. H, W = query_size(flat_inputs)
  233. r_x = torch.randint(W, size=(1,))
  234. r_y = torch.randint(H, size=(1,))
  235. r = 0.5 * math.sqrt(1.0 - lam)
  236. r_w_half = int(r * W)
  237. r_h_half = int(r * H)
  238. x1 = int(torch.clamp(r_x - r_w_half, min=0))
  239. y1 = int(torch.clamp(r_y - r_h_half, min=0))
  240. x2 = int(torch.clamp(r_x + r_w_half, max=W))
  241. y2 = int(torch.clamp(r_y + r_h_half, max=H))
  242. box = (x1, y1, x2, y2)
  243. lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
  244. return dict(box=box, lam_adjusted=lam_adjusted)
  245. def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
  246. if inpt is params["labels"]:
  247. return self._mixup_label(inpt, lam=params["lam_adjusted"])
  248. elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
  249. self._check_image_or_video(inpt, batch_size=params["batch_size"])
  250. x1, y1, x2, y2 = params["box"]
  251. rolled = inpt.roll(1, 0)
  252. output = inpt.clone()
  253. output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
  254. if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
  255. output = tv_tensors.wrap(output, like=inpt)
  256. return output
  257. else:
  258. return inpt