# _container.py — container transforms (Compose, RandomApply, RandomChoice, RandomOrder)
  1. from typing import Any, Callable, Dict, List, Optional, Sequence, Union
  2. import torch
  3. from torch import nn
  4. from torchvision import transforms as _transforms
  5. from torchvision.transforms.v2 import Transform
  6. class Compose(Transform):
  7. """[BETA] Composes several transforms together.
  8. .. v2betastatus:: Compose transform
  9. This transform does not support torchscript.
  10. Please, see the note below.
  11. Args:
  12. transforms (list of ``Transform`` objects): list of transforms to compose.
  13. Example:
  14. >>> transforms.Compose([
  15. >>> transforms.CenterCrop(10),
  16. >>> transforms.PILToTensor(),
  17. >>> transforms.ConvertImageDtype(torch.float),
  18. >>> ])
  19. .. note::
  20. In order to script the transformations, please use ``torch.nn.Sequential`` as below.
  21. >>> transforms = torch.nn.Sequential(
  22. >>> transforms.CenterCrop(10),
  23. >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  24. >>> )
  25. >>> scripted_transforms = torch.jit.script(transforms)
  26. Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
  27. `lambda` functions or ``PIL.Image``.
  28. """
  29. def __init__(self, transforms: Sequence[Callable]) -> None:
  30. super().__init__()
  31. if not isinstance(transforms, Sequence):
  32. raise TypeError("Argument transforms should be a sequence of callables")
  33. elif not transforms:
  34. raise ValueError("Pass at least one transform")
  35. self.transforms = transforms
  36. def forward(self, *inputs: Any) -> Any:
  37. needs_unpacking = len(inputs) > 1
  38. for transform in self.transforms:
  39. outputs = transform(*inputs)
  40. inputs = outputs if needs_unpacking else (outputs,)
  41. return outputs
  42. def extra_repr(self) -> str:
  43. format_string = []
  44. for t in self.transforms:
  45. format_string.append(f" {t}")
  46. return "\n".join(format_string)
  47. class RandomApply(Transform):
  48. """[BETA] Apply randomly a list of transformations with a given probability.
  49. .. v2betastatus:: RandomApply transform
  50. .. note::
  51. In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
  52. transforms as shown below:
  53. >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
  54. >>> transforms.ColorJitter(),
  55. >>> ]), p=0.3)
  56. >>> scripted_transforms = torch.jit.script(transforms)
  57. Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
  58. `lambda` functions or ``PIL.Image``.
  59. Args:
  60. transforms (sequence or torch.nn.Module): list of transformations
  61. p (float): probability of applying the list of transforms
  62. """
  63. _v1_transform_cls = _transforms.RandomApply
  64. def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
  65. super().__init__()
  66. if not isinstance(transforms, (Sequence, nn.ModuleList)):
  67. raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
  68. self.transforms = transforms
  69. if not (0.0 <= p <= 1.0):
  70. raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
  71. self.p = p
  72. def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
  73. return {"transforms": self.transforms, "p": self.p}
  74. def forward(self, *inputs: Any) -> Any:
  75. sample = inputs if len(inputs) > 1 else inputs[0]
  76. if torch.rand(1) >= self.p:
  77. return sample
  78. for transform in self.transforms:
  79. sample = transform(sample)
  80. return sample
  81. def extra_repr(self) -> str:
  82. format_string = []
  83. for t in self.transforms:
  84. format_string.append(f" {t}")
  85. return "\n".join(format_string)
  86. class RandomChoice(Transform):
  87. """[BETA] Apply single transformation randomly picked from a list.
  88. .. v2betastatus:: RandomChoice transform
  89. This transform does not support torchscript.
  90. Args:
  91. transforms (sequence or torch.nn.Module): list of transformations
  92. p (list of floats or None, optional): probability of each transform being picked.
  93. If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
  94. (default), all transforms have the same probability.
  95. """
  96. def __init__(
  97. self,
  98. transforms: Sequence[Callable],
  99. p: Optional[List[float]] = None,
  100. ) -> None:
  101. if not isinstance(transforms, Sequence):
  102. raise TypeError("Argument transforms should be a sequence of callables")
  103. if p is None:
  104. p = [1] * len(transforms)
  105. elif len(p) != len(transforms):
  106. raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")
  107. super().__init__()
  108. self.transforms = transforms
  109. total = sum(p)
  110. self.p = [prob / total for prob in p]
  111. def forward(self, *inputs: Any) -> Any:
  112. idx = int(torch.multinomial(torch.tensor(self.p), 1))
  113. transform = self.transforms[idx]
  114. return transform(*inputs)
  115. class RandomOrder(Transform):
  116. """[BETA] Apply a list of transformations in a random order.
  117. .. v2betastatus:: RandomOrder transform
  118. This transform does not support torchscript.
  119. Args:
  120. transforms (sequence or torch.nn.Module): list of transformations
  121. """
  122. def __init__(self, transforms: Sequence[Callable]) -> None:
  123. if not isinstance(transforms, Sequence):
  124. raise TypeError("Argument transforms should be a sequence of callables")
  125. super().__init__()
  126. self.transforms = transforms
  127. def forward(self, *inputs: Any) -> Any:
  128. sample = inputs if len(inputs) > 1 else inputs[0]
  129. for idx in torch.randperm(len(self.transforms)):
  130. transform = self.transforms[idx]
  131. sample = transform(sample)
  132. return sample