# plot_transforms_illustrations.py (12 KB)
  1. """
  2. ==========================
  3. Illustration of transforms
  4. ==========================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_illustrations.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_transforms_illustrations.py>` to download the full example code.
  8. This example illustrates some of the various transforms available in :ref:`the
  9. torchvision.transforms.v2 module <transforms>`.
  10. """
  11. # %%
  12. # sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png"
  13. from PIL import Image
  14. from pathlib import Path
  15. import matplotlib.pyplot as plt
  16. import torch
  17. from torchvision.transforms import v2
  18. plt.rcParams["savefig.bbox"] = 'tight'
  19. # if you change the seed, make sure that the randomly-applied transforms
  20. # properly show that the image can be both transformed and *not* transformed!
  21. torch.manual_seed(0)
  22. # If you're trying to run that on collab, you can download the assets and the
  23. # helpers from https://github.com/pytorch/vision/tree/main/gallery/
  24. from helpers import plot
  25. orig_img = Image.open(Path('../assets') / 'astronaut.jpg')
  26. # %%
  27. # Geometric Transforms
  28. # --------------------
  29. # Geometric image transformation refers to the process of altering the geometric properties of an image,
  30. # such as its shape, size, orientation, or position.
  31. # It involves applying mathematical operations to the image pixels or coordinates to achieve the desired transformation.
  32. #
  33. # Pad
  34. # ~~~
  35. # The :class:`~torchvision.transforms.Pad` transform
  36. # (see also :func:`~torchvision.transforms.functional.pad`)
  37. # pads all image borders with some pixel values.
  38. padded_imgs = [v2.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
  39. plot([orig_img] + padded_imgs)
  40. # %%
  41. # Resize
  42. # ~~~~~~
  43. # The :class:`~torchvision.transforms.Resize` transform
  44. # (see also :func:`~torchvision.transforms.functional.resize`)
  45. # resizes an image.
  46. resized_imgs = [v2.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
  47. plot([orig_img] + resized_imgs)
  48. # %%
  49. # CenterCrop
  50. # ~~~~~~~~~~
  51. # The :class:`~torchvision.transforms.CenterCrop` transform
  52. # (see also :func:`~torchvision.transforms.functional.center_crop`)
  53. # crops the given image at the center.
  54. center_crops = [v2.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
  55. plot([orig_img] + center_crops)
  56. # %%
  57. # FiveCrop
  58. # ~~~~~~~~
  59. # The :class:`~torchvision.transforms.FiveCrop` transform
  60. # (see also :func:`~torchvision.transforms.functional.five_crop`)
  61. # crops the given image into four corners and the central crop.
  62. (top_left, top_right, bottom_left, bottom_right, center) = v2.FiveCrop(size=(100, 100))(orig_img)
  63. plot([orig_img] + [top_left, top_right, bottom_left, bottom_right, center])
  64. # %%
  65. # RandomPerspective
  66. # ~~~~~~~~~~~~~~~~~
  67. # The :class:`~torchvision.transforms.RandomPerspective` transform
  68. # (see also :func:`~torchvision.transforms.functional.perspective`)
  69. # performs random perspective transform on an image.
  70. perspective_transformer = v2.RandomPerspective(distortion_scale=0.6, p=1.0)
  71. perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
  72. plot([orig_img] + perspective_imgs)
  73. # %%
  74. # RandomRotation
  75. # ~~~~~~~~~~~~~~
  76. # The :class:`~torchvision.transforms.RandomRotation` transform
  77. # (see also :func:`~torchvision.transforms.functional.rotate`)
  78. # rotates an image with random angle.
  79. rotater = v2.RandomRotation(degrees=(0, 180))
  80. rotated_imgs = [rotater(orig_img) for _ in range(4)]
  81. plot([orig_img] + rotated_imgs)
  82. # %%
  83. # RandomAffine
  84. # ~~~~~~~~~~~~
  85. # The :class:`~torchvision.transforms.RandomAffine` transform
  86. # (see also :func:`~torchvision.transforms.functional.affine`)
  87. # performs random affine transform on an image.
  88. affine_transfomer = v2.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
  89. affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
  90. plot([orig_img] + affine_imgs)
  91. # %%
  92. # ElasticTransform
  93. # ~~~~~~~~~~~~~~~~
  94. # The :class:`~torchvision.transforms.ElasticTransform` transform
  95. # (see also :func:`~torchvision.transforms.functional.elastic_transform`)
  96. # Randomly transforms the morphology of objects in images and produces a
  97. # see-through-water-like effect.
  98. elastic_transformer = v2.ElasticTransform(alpha=250.0)
  99. transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)]
  100. plot([orig_img] + transformed_imgs)
  101. # %%
  102. # RandomCrop
  103. # ~~~~~~~~~~
  104. # The :class:`~torchvision.transforms.RandomCrop` transform
  105. # (see also :func:`~torchvision.transforms.functional.crop`)
  106. # crops an image at a random location.
  107. cropper = v2.RandomCrop(size=(128, 128))
  108. crops = [cropper(orig_img) for _ in range(4)]
  109. plot([orig_img] + crops)
  110. # %%
  111. # RandomResizedCrop
  112. # ~~~~~~~~~~~~~~~~~
  113. # The :class:`~torchvision.transforms.RandomResizedCrop` transform
  114. # (see also :func:`~torchvision.transforms.functional.resized_crop`)
  115. # crops an image at a random location, and then resizes the crop to a given
  116. # size.
  117. resize_cropper = v2.RandomResizedCrop(size=(32, 32))
  118. resized_crops = [resize_cropper(orig_img) for _ in range(4)]
  119. plot([orig_img] + resized_crops)
  120. # %%
  121. # Photometric Transforms
  122. # ----------------------
  123. # Photometric image transformation refers to the process of modifying the photometric properties of an image,
  124. # such as its brightness, contrast, color, or tone.
  125. # These transformations are applied to change the visual appearance of an image
  126. # while preserving its geometric structure.
  127. #
  128. # Except :class:`~torchvision.transforms.Grayscale`, the following transforms are random,
  129. # which means that the same transform
  130. # instance will produce different result each time it transforms a given image.
  131. #
  132. # Grayscale
  133. # ~~~~~~~~~
  134. # The :class:`~torchvision.transforms.Grayscale` transform
  135. # (see also :func:`~torchvision.transforms.functional.to_grayscale`)
  136. # converts an image to grayscale
  137. gray_img = v2.Grayscale()(orig_img)
  138. plot([orig_img, gray_img], cmap='gray')
  139. # %%
  140. # ColorJitter
  141. # ~~~~~~~~~~~
  142. # The :class:`~torchvision.transforms.ColorJitter` transform
  143. # randomly changes the brightness, contrast, saturation, hue, and other properties of an image.
  144. jitter = v2.ColorJitter(brightness=.5, hue=.3)
  145. jittered_imgs = [jitter(orig_img) for _ in range(4)]
  146. plot([orig_img] + jittered_imgs)
  147. # %%
  148. # GaussianBlur
  149. # ~~~~~~~~~~~~
  150. # The :class:`~torchvision.transforms.GaussianBlur` transform
  151. # (see also :func:`~torchvision.transforms.functional.gaussian_blur`)
  152. # performs gaussian blur transform on an image.
  153. blurrer = v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.))
  154. blurred_imgs = [blurrer(orig_img) for _ in range(4)]
  155. plot([orig_img] + blurred_imgs)
  156. # %%
  157. # RandomInvert
  158. # ~~~~~~~~~~~~
  159. # The :class:`~torchvision.transforms.RandomInvert` transform
  160. # (see also :func:`~torchvision.transforms.functional.invert`)
  161. # randomly inverts the colors of the given image.
  162. inverter = v2.RandomInvert()
  163. invertered_imgs = [inverter(orig_img) for _ in range(4)]
  164. plot([orig_img] + invertered_imgs)
  165. # %%
  166. # RandomPosterize
  167. # ~~~~~~~~~~~~~~~
  168. # The :class:`~torchvision.transforms.RandomPosterize` transform
  169. # (see also :func:`~torchvision.transforms.functional.posterize`)
  170. # randomly posterizes the image by reducing the number of bits
  171. # of each color channel.
  172. posterizer = v2.RandomPosterize(bits=2)
  173. posterized_imgs = [posterizer(orig_img) for _ in range(4)]
  174. plot([orig_img] + posterized_imgs)
  175. # %%
  176. # RandomSolarize
  177. # ~~~~~~~~~~~~~~
  178. # The :class:`~torchvision.transforms.RandomSolarize` transform
  179. # (see also :func:`~torchvision.transforms.functional.solarize`)
  180. # randomly solarizes the image by inverting all pixel values above
  181. # the threshold.
  182. solarizer = v2.RandomSolarize(threshold=192.0)
  183. solarized_imgs = [solarizer(orig_img) for _ in range(4)]
  184. plot([orig_img] + solarized_imgs)
  185. # %%
  186. # RandomAdjustSharpness
  187. # ~~~~~~~~~~~~~~~~~~~~~
  188. # The :class:`~torchvision.transforms.RandomAdjustSharpness` transform
  189. # (see also :func:`~torchvision.transforms.functional.adjust_sharpness`)
  190. # randomly adjusts the sharpness of the given image.
  191. sharpness_adjuster = v2.RandomAdjustSharpness(sharpness_factor=2)
  192. sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)]
  193. plot([orig_img] + sharpened_imgs)
  194. # %%
  195. # RandomAutocontrast
  196. # ~~~~~~~~~~~~~~~~~~
  197. # The :class:`~torchvision.transforms.RandomAutocontrast` transform
  198. # (see also :func:`~torchvision.transforms.functional.autocontrast`)
  199. # randomly applies autocontrast to the given image.
  200. autocontraster = v2.RandomAutocontrast()
  201. autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)]
  202. plot([orig_img] + autocontrasted_imgs)
  203. # %%
  204. # RandomEqualize
  205. # ~~~~~~~~~~~~~~
  206. # The :class:`~torchvision.transforms.RandomEqualize` transform
  207. # (see also :func:`~torchvision.transforms.functional.equalize`)
  208. # randomly equalizes the histogram of the given image.
  209. equalizer = v2.RandomEqualize()
  210. equalized_imgs = [equalizer(orig_img) for _ in range(4)]
  211. plot([orig_img] + equalized_imgs)
  212. # %%
  213. # Augmentation Transforms
  214. # -----------------------
  215. # The following transforms are combinations of multiple transforms,
  216. # either geometric or photometric, or both.
  217. #
  218. # AutoAugment
  219. # ~~~~~~~~~~~
  220. # The :class:`~torchvision.transforms.AutoAugment` transform
  221. # automatically augments data based on a given auto-augmentation policy.
  222. # See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
  223. policies = [v2.AutoAugmentPolicy.CIFAR10, v2.AutoAugmentPolicy.IMAGENET, v2.AutoAugmentPolicy.SVHN]
  224. augmenters = [v2.AutoAugment(policy) for policy in policies]
  225. imgs = [
  226. [augmenter(orig_img) for _ in range(4)]
  227. for augmenter in augmenters
  228. ]
  229. row_title = [str(policy).split('.')[-1] for policy in policies]
  230. plot([[orig_img] + row for row in imgs], row_title=row_title)
  231. # %%
  232. # RandAugment
  233. # ~~~~~~~~~~~
  234. # The :class:`~torchvision.transforms.RandAugment` is an alternate version of AutoAugment.
  235. augmenter = v2.RandAugment()
  236. imgs = [augmenter(orig_img) for _ in range(4)]
  237. plot([orig_img] + imgs)
  238. # %%
  239. # TrivialAugmentWide
  240. # ~~~~~~~~~~~~~~~~~~
  241. # The :class:`~torchvision.transforms.TrivialAugmentWide` is an alternate implementation of AutoAugment.
  242. # However, instead of transforming an image multiple times, it transforms an image only once
  243. # using a random transform from a given list with a random strength number.
  244. augmenter = v2.TrivialAugmentWide()
  245. imgs = [augmenter(orig_img) for _ in range(4)]
  246. plot([orig_img] + imgs)
  247. # %%
  248. # AugMix
  249. # ~~~~~~
  250. # The :class:`~torchvision.transforms.AugMix` transform interpolates between augmented versions of an image.
  251. augmenter = v2.AugMix()
  252. imgs = [augmenter(orig_img) for _ in range(4)]
  253. plot([orig_img] + imgs)
  254. # %%
  255. # Randomly-applied Transforms
  256. # ---------------------------
  257. #
  258. # The following transforms are randomly-applied given a probability ``p``. That is, given ``p = 0.5``,
  259. # there is a 50% chance to return the original image, and a 50% chance to return the transformed image,
  260. # even when called with the same transform instance!
  261. #
  262. # RandomHorizontalFlip
  263. # ~~~~~~~~~~~~~~~~~~~~
  264. # The :class:`~torchvision.transforms.RandomHorizontalFlip` transform
  265. # (see also :func:`~torchvision.transforms.functional.hflip`)
  266. # performs horizontal flip of an image, with a given probability.
  267. hflipper = v2.RandomHorizontalFlip(p=0.5)
  268. transformed_imgs = [hflipper(orig_img) for _ in range(4)]
  269. plot([orig_img] + transformed_imgs)
  270. # %%
  271. # RandomVerticalFlip
  272. # ~~~~~~~~~~~~~~~~~~
  273. # The :class:`~torchvision.transforms.RandomVerticalFlip` transform
  274. # (see also :func:`~torchvision.transforms.functional.vflip`)
  275. # performs vertical flip of an image, with a given probability.
  276. vflipper = v2.RandomVerticalFlip(p=0.5)
  277. transformed_imgs = [vflipper(orig_img) for _ in range(4)]
  278. plot([orig_img] + transformed_imgs)
  279. # %%
  280. # RandomApply
  281. # ~~~~~~~~~~~
  282. # The :class:`~torchvision.transforms.RandomApply` transform
  283. # randomly applies a list of transforms, with a given probability.
  284. applier = v2.RandomApply(transforms=[v2.RandomCrop(size=(64, 64))], p=0.5)
  285. transformed_imgs = [applier(orig_img) for _ in range(4)]
  286. plot([orig_img] + transformed_imgs)