plot_cutmix_mixup.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
"""
===========================
How to use CutMix and MixUp
===========================

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_cutmix_mixup.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_cutmix_mixup.py>` to download the full example code.

:class:`~torchvision.transforms.v2.CutMix` and
:class:`~torchvision.transforms.v2.MixUp` are popular augmentation strategies
that can improve classification accuracy.

These transforms are slightly different from the rest of the Torchvision
transforms, because they expect **batches** of samples as input, not
individual images. In this example we'll explain how to use them: after the
``DataLoader``, or as part of a collation function.
"""
  17. # %%
  18. import torch
  19. from torchvision.datasets import FakeData
  20. from torchvision.transforms import v2
  21. NUM_CLASSES = 100

# %%
# Pre-processing pipeline
# -----------------------
#
# We'll use a simple but typical image classification pipeline:
  27. preproc = v2.Compose([
  28. v2.PILToTensor(),
  29. v2.RandomResizedCrop(size=(224, 224), antialias=True),
  30. v2.RandomHorizontalFlip(p=0.5),
  31. v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1]
  32. v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet
  33. ])
  34. dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)
  35. img, label = dataset[0]
  36. print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")

# %%
#
# One important thing to note is that neither CutMix nor MixUp is part of this
# pre-processing pipeline. We'll add them a bit later once we define the
# DataLoader. Just as a refresher, this is what the DataLoader and training loop
# would look like if we weren't using CutMix or MixUp:
  43. from torch.utils.data import DataLoader
  44. dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
  45. for images, labels in dataloader:
  46. print(f"{images.shape = }, {labels.shape = }")
  47. print(labels.dtype)
  48. # <rest of the training loop here>
  49. break

# %%
# Where to use MixUp and CutMix
# -----------------------------
#
# After the DataLoader
# ^^^^^^^^^^^^^^^^^^^^
#
# Now let's add CutMix and MixUp. The simplest way to do this is right after the
# DataLoader: the DataLoader has already batched the images and labels for us,
# and this is exactly what these transforms expect as input:
  61. dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
  62. cutmix = v2.CutMix(num_classes=NUM_CLASSES)
  63. mixup = v2.MixUp(num_classes=NUM_CLASSES)
  64. cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
  65. for images, labels in dataloader:
  66. print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
  67. images, labels = cutmix_or_mixup(images, labels)
  68. print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")
  69. # <rest of the training loop here>
  70. break

# %%
#
# Note how the labels were also transformed: we went from a batched label of
# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The
# transformed labels can still be passed as-is to a loss function like
# :func:`torch.nn.functional.cross_entropy`.
#
# As part of the collation function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Passing the transforms after the DataLoader is the simplest way to use CutMix
# and MixUp, but one disadvantage is that it does not take advantage of the
# DataLoader's multi-processing. For that, we can pass those transforms as part
# of the collation function (refer to the `PyTorch docs
# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn
# more about collation).
  87. from torch.utils.data import default_collate
  88. def collate_fn(batch):
  89. return cutmix_or_mixup(*default_collate(batch))
  90. dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)
  91. for images, labels in dataloader:
  92. print(f"{images.shape = }, {labels.shape = }")
  93. # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
  94. # <rest of the training loop here>
  95. break

# %%
# Non-standard input format
# -------------------------
#
# So far we've used a typical sample structure where we pass
# ``(images, labels)`` as inputs. MixUp and CutMix will magically work by
# default with the most common sample structures: tuples where the second
# element is a tensor label, or dicts with a "label[s]" key. Look at the
# documentation of the ``labels_getter`` parameter for more details.
#
# If your samples have a different structure, you can still use CutMix and MixUp
# by passing a callable to the ``labels_getter`` parameter. For example:
  108. batch = {
  109. "imgs": torch.rand(4, 3, 224, 224),
  110. "target": {
  111. "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
  112. "some_other_key": "this is going to be passed-through"
  113. }
  114. }
  115. def labels_getter(batch):
  116. return batch["target"]["classes"]
  117. out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
  118. print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")