_transform.py

from __future__ import annotations

import enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import tv_tensors
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.utils import _log_api_usage_once

from .functional._utils import _get_kernel


class Transform(nn.Module):
    # Class attribute defining transformed types. Other types are passed through without any transformation.
    # We support both types and callables that are able to do further checks on the type of the input.
    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
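
    # A minimal sketch of how a subclass might narrow this attribute; `MaskOnlyFlip`
    # is hypothetical and assumes the v2 functional namespace:
    #
    #     import torchvision.transforms.v2.functional as F
    #
    #     class MaskOnlyFlip(Transform):
    #         _transformed_types = (tv_tensors.Mask,)  # everything else passes through
    #
    #         def _transform(self, inpt, params):
    #             return self._call_kernel(F.horizontal_flip, inpt)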

    def __init__(self) -> None:
        super().__init__()
        _log_api_usage_once(self)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        pass

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return dict()

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
        return kernel(inpt, *args, **kwargs)
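
    # Note: with `allow_passthrough=True`, `_get_kernel` falls back to an identity
    # kernel when no kernel is registered for `type(inpt)`, so such inputs are
    # returned unchanged rather than raising.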

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

        self._check_inputs(flat_inputs)

        needs_transform_list = self._needs_transform_list(flat_inputs)
        params = self._get_params(
            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
        )

        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)
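
    # Sketch of the call-side contract (`MyTransform` is a hypothetical subclass):
    # arbitrarily nested inputs survive the round trip through `tree_flatten` /
    # `tree_unflatten`, so the output mirrors the input structure.
    #
    #     sample = {"img": tv_tensors.Image(torch.rand(3, 8, 8)), "label": 3}
    #     out = MyTransform()(sample)
    #     # `out` is a dict with the same keys; only "img" was transformed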

    def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
        # Below is a heuristic on how to deal with pure tensor inputs:
        # 1. Pure tensors, i.e. tensors that are not a tv_tensor, are passed through if there is an explicit image
        #    (`tv_tensors.Image` or `PIL.Image.Image`) or video (`tv_tensors.Video`) in the sample.
        # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
        #    transformed as an image, while the rest are passed through. The order is defined by the returned
        #    `flat_inputs` of `tree_flatten`, which recurses depth-first through the input.
        #
        # This heuristic stems from two requirements:
        # 1. We need to keep BC for single input pure tensors and treat them as images.
        # 2. We don't want to treat all pure tensors as images, because some datasets like `CelebA` or `WIDERFace`
        #    return supplemental numerical data as tensors that cannot be transformed as images.
        #
        # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
        # tries to transform multiple pure tensors at the same time, expecting them all to be treated as images.
        # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
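        #
        # Illustration of the heuristic (hypothetical samples, not executed):
        #     [Image, Tensor, Tensor] -> [True, False, False]  # explicit image present
        #     [Tensor, Tensor]        -> [True, False]         # first pure tensor is treated as the image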
        needs_transform_list = []
        transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
        for inpt in flat_inputs:
            needs_transform = True

            if not check_type(inpt, self._transformed_types):
                needs_transform = False
            elif is_pure_tensor(inpt):
                if transform_pure_tensor:
                    transform_pure_tensor = False
                else:
                    needs_transform = False

            needs_transform_list.append(needs_transform)
        return needs_transform_list

    def extra_repr(self) -> str:
        extra = []
        for name, value in self.__dict__.items():
            if name.startswith("_") or name == "training":
                continue

            if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
                continue

            extra.append(f"{name}={value}")

        return ", ".join(extra)

    # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things:
    # 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on
    #    the v2 transform. See `__init_subclass__` for details.
    # 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__`
    #    for details.
    _v1_transform_cls: Optional[Type[nn.Module]] = None

    def __init_subclass__(cls) -> None:
        # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
        # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
        if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
            cls.get_params = staticmethod(cls._v1_transform_cls.get_params)  # type: ignore[attr-defined]
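
    # A minimal sketch of the v1 wiring (`MyCrop` is hypothetical; it assumes the
    # v1 class exposes a static `get_params`, as e.g. `transforms.RandomResizedCrop` does):
    #
    #     import torchvision.transforms as _v1
    #
    #     class MyCrop(Transform):
    #         _v1_transform_cls = _v1.RandomResizedCrop
    #
    #     MyCrop.get_params  # bound here by `__init_subclass__`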

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the
        # current v2 transform instance. It extracts all available public attributes that are specific to that
        # transform and not `nn.Module` in general.
        # Override this method on the v2 transform class if the above is not sufficient. For example, this might
        # happen if the v2 transform introduced new parameters that are not supported by the v1 transform.
        common_attrs = nn.Module().__dict__.keys()
        return {
            attr: value
            for attr, value in self.__dict__.items()
            if not attr.startswith("_") and attr not in common_attrs
        }
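
    # Illustration (hypothetical): for an instance whose `__init__` set `self.size = 224`
    # and `self.p = 0.5`, this returns `{"size": 224, "p": 0.5}`, while `nn.Module`
    # bookkeeping such as `training` and the underscore-prefixed containers is filtered out.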

    def __prepare_scriptable__(self) -> nn.Module:
        # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the
        # returned value is scripted instead of the original object. Since the v1 transforms are JIT scriptable, and
        # we made sure that for single image inputs v1 and v2 are equivalent, we just return the equivalent v1
        # transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1 is around.
        if self._v1_transform_cls is None:
            raise RuntimeError(
                f"Transform {type(self).__name__} cannot be JIT scripted. "
                "torchscript is only supported for backward compatibility with transforms "
                "which are already in torchvision.transforms. "
                "For torchscript support (on tensors only), you can use the functional API instead."
            )

        return self._v1_transform_cls(**self._extract_params_for_v1_transform())
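
    # Sketch of the intended flow (assuming, as in torchvision, that
    # `v2.RandomHorizontalFlip` sets `_v1_transform_cls`):
    #
    #     t = v2.RandomHorizontalFlip(p=0.5)
    #     scripted = torch.jit.script(t)  # actually scripts the equivalent v1 transform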


class _RandomApplyTransform(Transform):
    def __init__(self, p: float = 0.5) -> None:
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

        super().__init__()
        self.p = p

    def forward(self, *inputs: Any) -> Any:
        # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but
        # return early afterwards in case the random check triggers. The same result could be achieved by calling
        # `super().forward()` after the random check, but that would call `self._check_inputs` twice.

        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)

        self._check_inputs(flat_inputs)

        if torch.rand(1) >= self.p:
            return inputs

        needs_transform_list = self._needs_transform_list(flat_inputs)
        params = self._get_params(
            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
        )

        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)
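

# A minimal usage sketch (`RandomInvert` is hypothetical; assumes the v2 functional
# namespace):
#
#     import torchvision.transforms.v2.functional as F
#
#     class RandomInvert(_RandomApplyTransform):
#         def _transform(self, inpt, params):
#             return self._call_kernel(F.invert, inpt)
#
#     out = RandomInvert(p=0.3)(img)  # inverts `img` with probability 0.3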