import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor


# The single input entry that the auto-augment transforms operate on.
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]


class _AutoAugmentBase(Transform):
    """Shared machinery for the auto-augmentation transforms defined in this module.

    Subclasses provide an ``_AUGMENTATION_SPACE`` mapping of operation name to a
    ``(magnitudes_fn, signed)`` pair and implement ``forward``. This base class supplies the
    helpers to locate the single image/video inside an arbitrary input structure, to apply one
    named operation to it, and to put the result back into the structure.

    Args:
        interpolation (InterpolationMode or int, optional): Interpolation passed to the geometric
            operations. It is normalised through ``_check_interpolation``.
            Default is ``InterpolationMode.NEAREST``.
        fill (optional): Fill value, or a dictionary of fill values, for the area outside the
            transformed image. The raw value is kept on ``self.fill`` and the result of
            ``_setup_fill_arg(fill)`` on ``self._fill``. Default is ``None``.
    """

    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Return the parent's v1 parameters, rejecting a dictionary ``fill``.

        Raises:
            ValueError: If the extracted ``"fill"`` parameter is a dictionary.
        """
        params = super()._extract_params_for_v1_transform()

        if isinstance(params["fill"], dict):
            raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")

        return params

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        """Draw one ``(key, value)`` pair from ``dct`` using a single ``torch.randint`` call."""
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        """Flatten ``inputs`` and locate the one image or video to be augmented.

        Args:
            inputs: The positional arguments received by ``forward``. When exactly one argument
                was passed, it is unwrapped before flattening.
            unsupported_types: Types that make the call fail when present among the flattened
                entries that were not selected as image/video.

        Returns:
            A pair ``((flat_inputs, spec, idx), image_or_video)`` where ``flat_inputs`` and
            ``spec`` come from ``tree_flatten`` and ``idx`` is the position of ``image_or_video``
            inside ``flat_inputs``.

        Raises:
            TypeError: If an entry of an unsupported type is found, if no image/video is found,
                or if more than one image/video is found.
        """
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    tv_tensors.Image,
                    PIL.Image.Image,
                    is_pure_tensor,
                    tv_tensors.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        """Write ``image_or_video`` back at its recorded index and rebuild the input structure.

        Note that the ``flat_inputs`` list inside ``flat_inputs_with_spec`` is modified in place.
        """
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)

    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        """Apply the operation named ``transform_id`` to ``image``.

        Args:
            image: The image or video to transform.
            transform_id: Name of the operation, e.g. ``"ShearX"`` or ``"Equalize"``.
            magnitude: Strength of the operation. It is ignored by ``"Identity"``,
                ``"AutoContrast"``, ``"Equalize"`` and ``"Invert"``.
            interpolation: Interpolation forwarded to the affine and rotate operations.
            fill: Fill specification; the value used is ``_get_fill(fill, type(image))``.

        Returns:
            The transformed image or video.

        Raises:
            ValueError: If ``transform_id`` does not name a known operation.
        """
        fill_ = _get_fill(fill, type(image))

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision:      (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            # The magnitude is a fraction of the value range: the dtype's maximum for tensors, 255 otherwise.
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")


class AutoAugment(_AutoAugmentBase):
    r"""[BETA] AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    .. v2betastatus:: AutoAugment transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """
    _v1_transform_cls = _transforms.AutoAugment

    # op name -> (magnitudes_fn(num_bins, height, width), signed).
    # ``magnitudes_fn`` returns a tensor of candidate magnitudes, or ``None`` for ops without a magnitude.
    # When ``signed`` is true, ``forward`` negates the chosen magnitude about half of the time.
    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        """Return the sub-policies of ``policy``.

        Each sub-policy is a pair of ``(op_name, probability, magnitude_idx)`` triples, where
        ``magnitude_idx`` is ``None`` for operations without a magnitude.

        Raises:
            ValueError: If ``policy`` is not one of the IMAGENET, CIFAR10 or SVHN policies.
        """
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def forward(self, *inputs: Any) -> Any:
        """Pick one sub-policy at random and apply each of its operations with its own probability."""
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            # The policies index into a fixed grid of 10 magnitude bins.
            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class RandAugment(_AutoAugmentBase):
    r"""[BETA] RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    .. v2betastatus:: RandAugment transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
            Default is ``2``.
        magnitude (int, optional): Magnitude for all the transformations. It is used as an index into
            the ``num_magnitude_bins`` candidate magnitudes of each operation. Default is ``9``.
        num_magnitude_bins (int, optional): The number of different magnitude values. Default is ``31``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.RandAugment

    # op name -> (magnitudes_fn(num_bins, height, width), signed).
    # ``magnitudes_fn`` returns a tensor of candidate magnitudes, or ``None`` for ops without a magnitude.
    # When ``signed`` is true, ``forward`` negates the chosen magnitude about half of the time.
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        """Apply ``num_ops`` randomly chosen operations in sequence, all at the configured magnitude."""
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[self.magnitude])
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class TrivialAugmentWide(_AutoAugmentBase):
    r"""[BETA] Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    .. v2betastatus:: TrivialAugmentWide transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values. Default is ``31``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.TrivialAugmentWide

    # op name -> (magnitudes_fn(num_bins, height, width), signed).
    # ``magnitudes_fn`` returns a tensor of candidate magnitudes, or ``None`` for ops without a magnitude.
    # When ``signed`` is true, ``forward`` negates the chosen magnitude about half of the time.
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        """Apply one randomly chosen operation at a randomly chosen magnitude bin."""
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


class AugMix(_AutoAugmentBase):
    r"""[BETA] AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    .. v2betastatus:: AugMix transform

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of base augmentation operators. Must be in ``[1, 10]``.
            Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of augmentation chains. A non-positive value denotes stochastic
            depth sampled from the interval [1, 3]. Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness).
            Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.AugMix

    # Operations used when ``all_ops`` is false.
    # op name -> (magnitudes_fn(num_bins, height, width), signed), see ``forward`` for how they are consumed.
    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    # Operations used when ``all_ops`` is true: the partial space plus the color operations.
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        # Number of magnitude bins; it is also the upper bound accepted for ``severity``.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        """Mix the input with ``mixture_width`` randomly augmented copies of itself.

        PIL inputs are converted to a tensor for the computation and converted back at the end;
        :class:`~torchvision.tv_tensors.Image` and :class:`~torchvision.tv_tensors.Video` inputs
        are re-wrapped as their original type.
        """
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(inpt, PIL.Image.Image):
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        # Prepend singleton dimensions until the input has a leading batch dimension
        # (4 dims for images, 5 for videos).
        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    # Only the first ``severity`` bins are eligible.
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(
                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
                )
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_pil_image(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)