123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615 |
- import math
- from enum import Enum
- from typing import Dict, List, Optional, Tuple
- import torch
- from torch import Tensor
- from . import functional as F, InterpolationMode
- __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
def _apply_op(
    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
    """Apply the single augmentation operator named ``op_name`` to ``img``.

    ``magnitude`` is interpreted per-op (degrees for Rotate, pixels for the
    translations, a factor offset for the photometric ops, a threshold for
    Solarize, bit count for Posterize). Parameterless ops ignore it.
    Raises ``ValueError`` for an unknown ``op_name``.
    """
    if op_name in ("ShearX", "ShearY"):
        # The official autoaugment implementation puts the raw level straight
        # into the shear matrix: (1, level, 0, 0, 1, 0)
        # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
        # torchvision's affine instead takes an angle whose tangent becomes
        # that matrix entry: (1, tan(level), 0, 0, 1, 0)
        # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
        # so we pass atan(magnitude) (in degrees) to match the reference.
        shear_deg = math.degrees(math.atan(magnitude))
        shear = [shear_deg, 0.0] if op_name == "ShearX" else [0.0, shear_deg]
        return F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    if op_name in ("TranslateX", "TranslateY"):
        # Translation is in whole pixels along a single axis.
        shift = int(magnitude)
        translate = [shift, 0] if op_name == "TranslateX" else [0, shift]
        return F.affine(
            img,
            angle=0.0,
            translate=translate,
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    if op_name == "Rotate":
        return F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    if op_name == "Brightness":
        return F.adjust_brightness(img, 1.0 + magnitude)
    if op_name == "Color":
        return F.adjust_saturation(img, 1.0 + magnitude)
    if op_name == "Contrast":
        return F.adjust_contrast(img, 1.0 + magnitude)
    if op_name == "Sharpness":
        return F.adjust_sharpness(img, 1.0 + magnitude)
    if op_name == "Posterize":
        return F.posterize(img, int(magnitude))
    if op_name == "Solarize":
        return F.solarize(img, magnitude)
    if op_name == "AutoContrast":
        return F.autocontrast(img)
    if op_name == "Equalize":
        return F.equalize(img)
    if op_name == "Invert":
        return F.invert(img)
    if op_name == "Identity":
        return img
    raise ValueError(f"The provided operator {op_name} is not recognized.")
class AutoAugmentPolicy(Enum):
    """AutoAugment policies learned on different datasets.
    Available policies are IMAGENET, CIFAR10 and SVHN.
    """

    # Each value is the lower-case name of the dataset the policy was
    # searched on; _get_policies() below maps each member to its sub-policies.
    IMAGENET = "imagenet"
    CIFAR10 = "cifar10"
    SVHN = "svhn"
- # FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.interpolation = interpolation
        self.fill = fill
        # Resolve the policy enum to its list of sub-policies once, up front.
        self.policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        # Each sub-policy is a pair of (op_name, probability, magnitude_id).
        # magnitude_id indexes the 10-bin tables from _augmentation_space();
        # it is None for parameterless ops (Equalize, Invert, AutoContrast).
        # The tuples below are the learned policies from the paper — do not
        # edit the constants.
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        # Maps each op to its table of `num_bins` magnitudes and a flag saying
        # whether the magnitude may be negated (mirrored). 0-dim tensors mark
        # parameterless ops. Translation ranges scale with the image size
        # (150/331 comes from the reference implementation's image size).
        return {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            # Posterize counts down from 8 bits to 4 bits across the bins.
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }

    @staticmethod
    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
        """Get parameters for autoaugment transformation

        Returns:
            params required by the autoaugment transformation
        """
        # policy_id selects a sub-policy; probs/signs are one draw per op pair.
        policy_id = int(torch.randint(transform_num, (1,)).item())
        probs = torch.rand((2,))
        signs = torch.randint(2, (2,))
        return policy_id, probs, signs

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # Tensor backend needs the fill value expanded to one float per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        transform_id, probs, signs = self.get_params(len(self.policies))

        # AutoAugment always uses 10 magnitude bins.
        op_meta = self._augmentation_space(10, (height, width))
        for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
            # Each op in the sub-policy fires independently with probability p.
            if probs[i] <= p:
                magnitudes, signed = op_meta[op_name]
                magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                # Signed ops are mirrored with probability 0.5.
                if signed and signs[i] == 0:
                    magnitude *= -1.0
                img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})"
class RandAugment(torch.nn.Module):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        # Maps op_name -> (table of num_bins magnitudes, whether the op may be
        # negated). 0-dim tensors mark parameterless ops.
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # Tensor backend needs one float fill value per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
        op_names = list(op_meta.keys())
        for _ in range(self.num_ops):
            # Draw one operator uniformly at random per step.
            chosen = op_names[int(torch.randint(len(op_meta), (1,)).item())]
            magnitudes, signed = op_meta[chosen]
            # Parameterless ops (0-dim table) always get magnitude 0.
            level = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
            # Signed ops flip direction with probability 0.5.
            if signed and torch.randint(2, (1,)):
                level *= -1.0
            img = _apply_op(img, chosen, level, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        attrs = (
            f"num_ops={self.num_ops}",
            f"magnitude={self.magnitude}",
            f"num_magnitude_bins={self.num_magnitude_bins}",
            f"interpolation={self.interpolation}",
            f"fill={self.fill}",
        )
        return f"{self.__class__.__name__}(" + ", ".join(attrs) + ")"
class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        # Maps op_name -> (table of num_bins magnitudes, whether the op may be
        # negated). TrivialAugment "Wide" uses wider ranges than RandAugment
        # and does not scale translations by the image size.
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # Tensor backend needs one float fill value per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        # Sample exactly one operator and one magnitude bin uniformly.
        op_meta = self._augmentation_space(self.num_magnitude_bins)
        chosen = list(op_meta.keys())[int(torch.randint(len(op_meta), (1,)).item())]
        magnitudes, signed = op_meta[chosen]
        if magnitudes.ndim > 0:
            bin_index = torch.randint(len(magnitudes), (1,), dtype=torch.long)
            level = float(magnitudes[bin_index].item())
        else:
            # Parameterless ops (0-dim table) always get magnitude 0.
            level = 0.0
        # Signed ops flip direction with probability 0.5.
        if signed and torch.randint(2, (1,)):
            level *= -1.0

        return _apply_op(img, chosen, level, interpolation=self.interpolation, fill=fill)

    def __repr__(self) -> str:
        attrs = (
            f"num_magnitude_bins={self.num_magnitude_bins}",
            f"interpolation={self.interpolation}",
            f"fill={self.fill}",
        )
        return f"{self.__class__.__name__}(" + ", ".join(attrs) + ")"
class AugMix(torch.nn.Module):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # Number of magnitude bins; `severity` selects how many of the lowest
        # bins may be sampled from.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        # Base ops exclude the photometric ops unless all_ops is set; the
        # flag per entry says whether the magnitude may be negated, and 0-dim
        # tensors mark parameterless ops.
        s = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        if self.all_ops:
            s.update(
                {
                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                }
            )
        return s

    @torch.jit.unused
    def _pil_to_tensor(self, img) -> Tensor:
        # Kept as a method (rather than calling F directly in forward) so it
        # can be excluded from TorchScript via @torch.jit.unused.
        return F.pil_to_tensor(img)

    @torch.jit.unused
    def _tensor_to_pil(self, img: Tensor):
        return F.to_pil_image(img)

    def _sample_dirichlet(self, params: Tensor) -> Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, orig_img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(orig_img)
        if isinstance(orig_img, Tensor):
            img = orig_img
            # Tensor backend needs the fill value expanded to one float per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]
        else:
            # PIL input: work on a tensor copy, convert back at the end.
            img = self._pil_to_tensor(orig_img)

        op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))

        orig_dims = list(img.shape)
        # View as a 4-D batch [B, C, H, W] (B=1 for unbatched input) so the
        # mixing weights can broadcast uniformly below.
        batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].view([batch_dims[0], -1])

        # Start the mix with the original image weighted by the Beta draw.
        mix = m[:, 0].view(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # Negative chain_depth means stochastic depth in [1, 3].
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                op_index = int(torch.randint(len(op_meta), (1,)).item())
                op_name = list(op_meta.keys())[op_index]
                magnitudes, signed = op_meta[op_name]
                # Sample a magnitude bin in [0, severity); 0-dim tables are
                # parameterless ops and always get magnitude 0.
                magnitude = (
                    float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
                    if magnitudes.ndim > 0
                    else 0.0
                )
                # Signed ops flip direction with probability 0.5.
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0
                aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
            # Accumulate this chain's contribution in place.
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        # Restore the caller's original shape and dtype (e.g. back to uint8).
        mix = mix.view(orig_dims).to(dtype=img.dtype)

        if not isinstance(orig_img, Tensor):
            return self._tensor_to_pil(mix)
        return mix

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"severity={self.severity}"
            f", mixture_width={self.mixture_width}"
            f", chain_depth={self.chain_depth}"
            f", alpha={self.alpha}"
            f", all_ops={self.all_ops}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill}"
            f")"
        )
        return s
|