123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- from typing import Any, Callable, Dict, List, Optional, Sequence, Union
- import torch
- from torch import nn
- from torchvision import transforms as _transforms
- from torchvision.transforms.v2 import Transform
class Compose(Transform):
    """[BETA] Chain several transforms into a single callable.

    .. v2betastatus:: Compose transform

    The transforms are applied one after another, in the order given.
    This transform does not support torchscript; see the note below for an
    alternative.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order.
            Must be a non-empty sequence.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        To script a pipeline, build it with ``torch.nn.Sequential`` instead:

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Only scriptable transformations can be used this way, i.e. ones that work on
        ``torch.Tensor`` and need neither `lambda` functions nor ``PIL.Image``.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        """Validate and store the transforms.

        Raises:
            TypeError: if ``transforms`` is not a ``Sequence``.
            ValueError: if ``transforms`` is empty.
        """
        super().__init__()
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Run ``inputs`` through every transform in order and return the last result.

        With more than one positional input, each transform receives the
        previous transform's output unpacked as positional arguments; with a
        single input, the previous output is passed on as-is.
        """
        multiple_inputs = len(inputs) > 1
        for step in self.transforms:
            result = step(*inputs)
            # Re-wrap a single result so the next ``step(*inputs)`` call sees one argument.
            inputs = result if multiple_inputs else (result,)
        return result

    def extra_repr(self) -> str:
        """Return one line per contained transform, for the module repr."""
        return "\n".join(f" {t}" for t in self.transforms)
class RandomApply(Transform):
    """[BETA] Apply randomly a list of transformations with a given probability.

    .. v2betastatus:: RandomApply transform

    On each call, either every transform in the list is applied (in order) or
    none of them is.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, do not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.ModuleList): list of transformations
        p (float): probability of applying the list of transforms
    """

    # The v1 (``torchvision.transforms``) class corresponding to this transform.
    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        """Validate and store the transforms and the probability.

        Raises:
            TypeError: if ``transforms`` is neither a ``Sequence`` nor an ``nn.ModuleList``.
            ValueError: if ``p`` lies outside ``[0.0, 1.0]``.
        """
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Return this transform's parameters as a dict keyed by parameter name.

        NOTE(review): presumably used to construct ``_v1_transform_cls``; the
        consumer is not visible in this file.
        """
        return {"transforms": self.transforms, "p": self.p}

    def forward(self, *inputs: Any) -> Any:
        """Apply all transforms with probability ``p``, otherwise return the input unchanged.

        Multiple positional inputs are unpacked into each transform, the same
        calling convention ``Compose.forward`` uses, so a callable behaves the
        same whether it is wrapped in ``Compose`` or in ``RandomApply``.

        Returns:
            The single input/output object when called with one argument, or a
            tuple when called with several.
        """
        needs_unpacking = len(inputs) > 1

        if torch.rand(1) >= self.p:
            return inputs if needs_unpacking else inputs[0]

        for transform in self.transforms:
            outputs = transform(*inputs)
            # Re-wrap a single result so the next ``transform(*inputs)`` sees one argument.
            inputs = outputs if needs_unpacking else (outputs,)

        # Reading the result back from ``inputs`` also covers an empty
        # transform list, where the input is returned untouched.
        return inputs if needs_unpacking else inputs[0]

    def extra_repr(self) -> str:
        """Return one line per contained transform, for the module repr."""
        return "\n".join(f" {t}" for t in self.transforms)
class RandomChoice(Transform):
    """[BETA] Apply single transformation randomly picked from a list.

    .. v2betastatus:: RandomChoice transform

    This transform does not support torchscript.

    Args:
        transforms (sequence): non-empty sequence of transformations
        p (list of floats or None, optional): probability of each transform being picked.
            Values must be non-negative with a positive sum. If ``p`` doesn't sum to 1,
            it is automatically normalized. If ``None`` (default), all transforms have
            the same probability.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        """Validate the arguments and store the normalized probabilities.

        Raises:
            TypeError: if ``transforms`` is not a ``Sequence``.
            ValueError: if ``transforms`` is empty, if ``p`` has a different
                length than ``transforms``, or if ``p`` contains a negative
                value or does not have a positive sum.
        """
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if not transforms:
            # Nothing could ever be picked; fail here instead of inside forward().
            raise ValueError("Pass at least one transform")

        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")

        if any(prob < 0 for prob in p):
            raise ValueError("`p` should only contain non-negative values.")
        total = sum(p)
        # ``not total > 0`` (rather than ``total <= 0``) also rejects NaN, and
        # prevents the division below from raising ZeroDivisionError.
        if not total > 0:
            raise ValueError("`p` should contain at least one positive value.")

        super().__init__()

        self.transforms = transforms
        self.p = [prob / total for prob in p]

    def forward(self, *inputs: Any) -> Any:
        """Draw one transform according to ``self.p`` and apply it to ``inputs``."""
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        transform = self.transforms[idx]
        return transform(*inputs)
class RandomOrder(Transform):
    """[BETA] Apply a list of transformations in a random order.

    .. v2betastatus:: RandomOrder transform

    This transform does not support torchscript.

    Args:
        transforms (sequence): list of transformations
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        """Validate and store the transforms.

        Raises:
            TypeError: if ``transforms`` is not a ``Sequence``.
        """
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Apply every transform exactly once, in a freshly shuffled order.

        Multiple positional inputs are unpacked into each transform, the same
        calling convention ``Compose.forward`` uses.

        Returns:
            The single input/output object when called with one argument, or a
            tuple when called with several. With no transforms the input is
            returned unchanged.
        """
        needs_unpacking = len(inputs) > 1
        # ``.tolist()`` yields plain ints, so indexing does not depend on the
        # sequence type accepting a 0-dim tensor as an index.
        for idx in torch.randperm(len(self.transforms)).tolist():
            outputs = self.transforms[idx](*inputs)
            # Re-wrap a single result so the next call sees one argument.
            inputs = outputs if needs_unpacking else (outputs,)
        return inputs if needs_unpacking else inputs[0]
|