import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor

ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]
class _AutoAugmentBase(Transform):
    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        params = super()._extract_params_for_v1_transform()

        if isinstance(params["fill"], dict):
            raise ValueError(f"{type(self).__name__}() cannot be scripted when `fill` is a dictionary.")

        return params

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]
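
    # The helpers below flatten the sample with ``torch.utils._pytree`` so that the
    # single image or video can be located, transformed, and written back to its
    # original position regardless of how the inputs are nested.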
    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    tv_tensors.Image,
                    PIL.Image.Image,
                    is_pure_tensor,
                    tv_tensors.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)
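
    # For example, given a sample like ``{"img": img, "label": 3}``, the two helpers
    # above extract ``img``, let the transform act on it, and rebuild a dict with the
    # same structure (sketch; assumes ``img`` is one of the supported types).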
    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        # Resolve the fill value for the concrete input type up front.
        fill_ = _get_fill(fill, type(image))

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision: (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            # The magnitude is interpreted as a fraction of the maximum value of the dtype.
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")
class AutoAugment(_AutoAugmentBase):
    r"""[BETA] AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    .. v2betastatus:: AutoAugment transform

    This transformation works on images and videos only.

    If the input is a :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If the input is a Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """
    _v1_transform_cls = _transforms.AutoAugment

    # Each entry maps a transform id to ``(magnitudes_fn, signed)``: ``magnitudes_fn``
    # returns the magnitude bins for a given number of bins and image size (or ``None``
    # if the transform takes no magnitude), and ``signed`` indicates whether the sign
    # of the sampled magnitude is flipped with probability 0.5.
    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }
    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)
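
    # Each policy below is a list of 25 sub-policies. A sub-policy is a pair of
    # ``(transform_id, probability, magnitude_idx)`` triples that are applied in
    # sequence, each with its own probability; ``forward`` samples one sub-policy
    # per call.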
    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")
    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class RandAugment(_AutoAugmentBase):
    r"""[BETA] RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    .. v2betastatus:: RandAugment transform

    This transformation works on images and videos only.

    If the input is a :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
        magnitude (int, optional): Magnitude for all the transformations.
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If the input is a Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """
    _v1_transform_cls = _transforms.RandAugment

    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[self.magnitude])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class TrivialAugmentWide(_AutoAugmentBase):
    r"""[BETA] Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    .. v2betastatus:: TrivialAugmentWide transform

    This transformation works on images and videos only.

    If the input is a :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If the input is a Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """
    _v1_transform_cls = _transforms.TrivialAugmentWide

    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class AugMix(_AutoAugmentBase):
    r"""[BETA] AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    .. v2betastatus:: AugMix transform

    This transformation works on images and videos only.

    If the input is a :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the input is a PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of the base augmentation operators. Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of the augmentation chains. A negative value denotes a stochastic
            depth sampled from the interval [1, 3]. Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness).
            Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If the input is a Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands.
    """
    _v1_transform_cls = _transforms.AugMix

    # The partial space omits brightness, color, contrast and sharpness, which the
    # AugMix paper excludes because they overlap with the corruptions used for
    # evaluation; ``all_ops`` toggles between the two spaces (see ``forward``).
    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }
    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be in the interval [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Kept as a separate method so that it can be overwritten in tests.
        return torch._sample_dirichlet(params)
    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # PIL.Image.Image
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # the augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(
                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
                )
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_pil_image(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
|