  1. """
  2. ===================================
  3. How to write your own v2 transforms
  4. ===================================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_transforms.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_transforms.py>` to download the full example code.
  8. This guide explains how to write transforms that are compatible with the
  9. torchvision transforms V2 API.
  10. """
# %%
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
# %%
# Just create a ``nn.Module`` and override the ``forward`` method
# ===============================================================
#
# In most cases, this is all you're going to need, as long as you already know
# the structure of the input that your transform will expect. For example if
# you're just doing image classification, your transform will typically accept a
# single image as input, or a ``(img, label)`` input. So you can just hard-code
# your ``forward`` method to accept just that, e.g.
#
# .. code:: python
#
#     class MyCustomTransform(torch.nn.Module):
#         def forward(self, img, label):
#             # Do some transformations
#             return new_img, new_label
#
# .. note::
#
#     This means that if you have a custom transform that is already compatible
#     with the V1 transforms (those in ``torchvision.transforms``), it will
#     still work with the V2 transforms without any change!
#
# We will illustrate this more completely below with a typical detection case,
# where our samples are just images, bounding boxes and labels:
class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this
        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing through the input
        return img, bboxes, label
transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)

# %%
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
#     While working with TVTensor classes in your code, make sure to
#     familiarize yourself with this section:
#     :ref:`tv_tensor_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
#
# In the section above, we have assumed that you already know the structure of
# your inputs and that you're OK with hard-coding this expected structure in
# your code. If you want your custom transforms to be as flexible as possible,
# this can be a bit limiting.
#
# A key feature of the builtin Torchvision V2 transforms is that they can accept
# arbitrary input structures and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all TVTensors are
# tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input. A minimal sketch of this
# flatten / transform / repack pattern is shown below.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
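# %%
# For illustration only, here is a minimal sketch of that flatten / transform /
# repack pattern. It only handles ``BoundingBoxes`` entries and uses the private
# ``torch.utils._pytree`` helpers linked above, so treat it as a toy version of
# the idea rather than the builtin transforms' actual implementation:
from torch.utils._pytree import tree_flatten, tree_unflatten


class ShiftBoxes(torch.nn.Module):
    """Toy transform that shifts every BoundingBoxes entry by one pixel."""

    def forward(self, sample):
        # 1. Unpack the arbitrary input structure into a flat list of leaves.
        flat, spec = tree_flatten(sample)
        # 2. Transform only the entries we know how to handle, deciding based on
        #    their class; everything else is passed through untouched.
        flat = [
            tv_tensors.wrap(entry + 1, like=entry)
            if isinstance(entry, tv_tensors.BoundingBoxes)
            else entry
            for entry in flat
        ]
        # 3. Repack the (potentially transformed) entries into the input structure.
        return tree_unflatten(flat, spec)


shifted = ShiftBoxes()(structured_input)
print(f"Shifted bboxes:\n{shifted['annotations'][0]}")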