- import math
- import numbers
- import random
- import warnings
- from collections.abc import Sequence
- from typing import List, Optional, Tuple, Union
- import torch
- from torch import Tensor
- try:
- import accimage
- except ImportError:
- accimage = None
- from ..utils import _log_api_usage_once
- from . import functional as F
- from .functional import _interpolation_modes_from_int, InterpolationMode
- __all__ = [
- "Compose",
- "ToTensor",
- "PILToTensor",
- "ConvertImageDtype",
- "ToPILImage",
- "Normalize",
- "Resize",
- "CenterCrop",
- "Pad",
- "Lambda",
- "RandomApply",
- "RandomChoice",
- "RandomOrder",
- "RandomCrop",
- "RandomHorizontalFlip",
- "RandomVerticalFlip",
- "RandomResizedCrop",
- "FiveCrop",
- "TenCrop",
- "LinearTransformation",
- "ColorJitter",
- "RandomRotation",
- "RandomAffine",
- "Grayscale",
- "RandomGrayscale",
- "RandomPerspective",
- "RandomErasing",
- "GaussianBlur",
- "InterpolationMode",
- "RandomInvert",
- "RandomPosterize",
- "RandomSolarize",
- "RandomAdjustSharpness",
- "RandomAutocontrast",
- "RandomEqualize",
- "ElasticTransform",
- ]
- class Compose:
- """Composes several transforms together. This transform does not support torchscript.
- Please see the note below.
- Args:
- transforms (list of ``Transform`` objects): list of transforms to compose.
- Example:
- >>> transforms.Compose([
- >>> transforms.CenterCrop(10),
- >>> transforms.PILToTensor(),
- >>> transforms.ConvertImageDtype(torch.float),
- >>> ])
- .. note::
- In order to script the transformations, please use ``torch.nn.Sequential`` as below.
- >>> transforms = torch.nn.Sequential(
- >>> transforms.CenterCrop(10),
- >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- >>> )
- >>> scripted_transforms = torch.jit.script(transforms)
- Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
- require ``lambda`` functions or ``PIL.Image``.
- """
- def __init__(self, transforms):
- if not torch.jit.is_scripting() and not torch.jit.is_tracing():
- _log_api_usage_once(self)
- self.transforms = transforms
- def __call__(self, img):
- for t in self.transforms:
- img = t(img)
- return img
- def __repr__(self) -> str:
- format_string = self.__class__.__name__ + "("
- for t in self.transforms:
- format_string += "\n"
- format_string += f" {t}"
- format_string += "\n)"
- return format_string
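- # Usage sketch (illustrative, not part of the original file): applying a composed pipeline to a PIL image.
- # Assumes Pillow is installed; "example.jpg" is a placeholder path.
- #   from PIL import Image
- #   preprocess = Compose([Resize(256), CenterCrop(224), PILToTensor(), ConvertImageDtype(torch.float)])
- #   img_tensor = preprocess(Image.open("example.jpg"))   # float tensor of shape (3, 224, 224) for an RGB input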
- class ToTensor:
- """Convert a PIL Image or ndarray to tensor and scale the values accordingly.
- This transform does not support torchscript.
- Converts a PIL Image or numpy.ndarray (H x W x C) in the range
- [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
- if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
- or if the numpy.ndarray has dtype = np.uint8
- In the other cases, tensors are returned without scaling.
- .. note::
- Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
- transforming target image masks. See the `references`_ for implementing the transforms for image masks.
- .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
- """
- def __init__(self) -> None:
- _log_api_usage_once(self)
- def __call__(self, pic):
- """
- Args:
- pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
- Returns:
- Tensor: Converted image.
- """
- return F.to_tensor(pic)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}()"
- class PILToTensor:
- """Convert a PIL Image to a tensor of the same type - this does not scale values.
- This transform does not support torchscript.
- Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
- """
- def __init__(self) -> None:
- _log_api_usage_once(self)
- def __call__(self, pic):
- """
- .. note::
- A deep copy of the underlying array is performed.
- Args:
- pic (PIL Image): Image to be converted to tensor.
- Returns:
- Tensor: Converted image.
- """
- return F.pil_to_tensor(pic)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}()"
- class ConvertImageDtype(torch.nn.Module):
- """Convert a tensor image to the given ``dtype`` and scale the values accordingly.
- This function does not support PIL Image.
- Args:
- dtype (torch.dtype): Desired data type of the output
- .. note::
- When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
- If converted back and forth, this mismatch has no effect.
- Raises:
- RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
- well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
- overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
- of the integer ``dtype``.
- """
- def __init__(self, dtype: torch.dtype) -> None:
- super().__init__()
- _log_api_usage_once(self)
- self.dtype = dtype
- def forward(self, image):
- return F.convert_image_dtype(image, self.dtype)
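- # Usage sketch (illustrative): rescaling a uint8 tensor image to float.
- #   img_uint8 = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
- #   img_float = ConvertImageDtype(torch.float32)(img_uint8)   # values rescaled from [0, 255] to [0.0, 1.0]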
- class ToPILImage:
- """Convert a tensor or an ndarray to PIL Image
- This transform does not support torchscript.
- Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
- H x W x C to a PIL Image while adjusting the value range depending on the ``mode``.
- Args:
- mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
- If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
- - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
- - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
- - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
- - If the input has 1 channel, the ``mode`` is determined by the data type (i.e. ``int``, ``float``, ``short``).
- .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
- """
- def __init__(self, mode=None):
- _log_api_usage_once(self)
- self.mode = mode
- def __call__(self, pic):
- """
- Args:
- pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
- Returns:
- PIL Image: Image converted to PIL Image.
- """
- return F.to_pil_image(pic, self.mode)
- def __repr__(self) -> str:
- format_string = self.__class__.__name__ + "("
- if self.mode is not None:
- format_string += f"mode={self.mode}"
- format_string += ")"
- return format_string
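- # Usage sketch (illustrative): round-tripping between a tensor and a PIL image.
- #   t = torch.rand(3, 64, 64)        # C x H x W float tensor in [0, 1]
- #   pil_img = ToPILImage()(t)        # 3-channel input, so mode is assumed to be "RGB"
- #   back = PILToTensor()(pil_img)    # uint8 tensor with values in [0, 255]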
- class Normalize(torch.nn.Module):
- """Normalize a tensor image with mean and standard deviation.
- This transform does not support PIL Image.
- Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
- channels, this transform will normalize each channel of the input
- ``torch.*Tensor`` i.e.,
- ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
- .. note::
- This transform acts out of place, i.e., it does not mutate the input tensor.
- Args:
- mean (sequence): Sequence of means for each channel.
- std (sequence): Sequence of standard deviations for each channel.
- inplace (bool, optional): Bool to make this operation in-place.
- """
- def __init__(self, mean, std, inplace=False):
- super().__init__()
- _log_api_usage_once(self)
- self.mean = mean
- self.std = std
- self.inplace = inplace
- def forward(self, tensor: Tensor) -> Tensor:
- """
- Args:
- tensor (Tensor): Tensor image to be normalized.
- Returns:
- Tensor: Normalized Tensor image.
- """
- return F.normalize(tensor, self.mean, self.std, self.inplace)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
- class Resize(torch.nn.Module):
- """Resize the input image to the given size.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means a maximum of two leading dimensions
- .. warning::
- The output image might be different depending on its type: when downsampling, the interpolation of PIL images
- and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
- in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
- types. See also below the ``antialias`` parameter, which can help make the output of PIL images and tensors
- closer.
- Args:
- size (sequence or int): Desired output size. If size is a sequence like
- (h, w), output size will be matched to this. If size is an int,
- the smaller edge of the image will be matched to this number,
- i.e., if height > width, then the image will be rescaled to
- (size * height / width, size).
- .. note::
- In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
- ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- max_size (int, optional): The maximum allowed size for the longer edge of
- the resized image. If the longer edge of the image is greater
- than ``max_size`` after being resized according to ``size``,
- ``size`` will be overruled so that the longer edge is equal to
- ``max_size``.
- As a result, the smaller edge may be shorter than ``size``. This
- is only supported if ``size`` is an int (or a sequence of length
- 1 in torchscript mode).
- antialias (bool, optional): Whether to apply antialiasing.
- It only affects **tensors** with bilinear or bicubic modes and it is
- ignored otherwise: on PIL images, antialiasing is always applied on
- bilinear or bicubic modes; on other modes (for PIL images and
- tensors), antialiasing makes no sense and this parameter is ignored.
- Possible values are:
- - ``True``: will apply antialiasing for bilinear or bicubic modes.
- Other modes aren't affected. This is probably what you want to use.
- - ``False``: will not apply antialiasing for tensors on any mode. PIL
- images are still antialiased on bilinear or bicubic modes, because
- PIL does not support disabling antialiasing.
- - ``None``: equivalent to ``False`` for tensors and ``True`` for
- PIL images. This value exists for legacy reasons and you probably
- don't want to use it unless you really know what you are doing.
- The current default is ``None`` **but will change to** ``True`` **in
- v0.17** for the PIL and Tensor backends to be consistent.
- """
- def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
- super().__init__()
- _log_api_usage_once(self)
- if not isinstance(size, (int, Sequence)):
- raise TypeError(f"Size should be int or sequence. Got {type(size)}")
- if isinstance(size, Sequence) and len(size) not in (1, 2):
- raise ValueError("If size is a sequence, it should have 1 or 2 values")
- self.size = size
- self.max_size = max_size
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.interpolation = interpolation
- self.antialias = antialias
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be scaled.
- Returns:
- PIL Image or Tensor: Rescaled image.
- """
- return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
- def __repr__(self) -> str:
- detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
- return f"{self.__class__.__name__}{detail}"
- class CenterCrop(torch.nn.Module):
- """Crops the given image at the center.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
- If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
- Args:
- size (sequence or int): Desired output size of the crop. If size is an
- int instead of sequence like (h, w), a square crop (size, size) is
- made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- """
- def __init__(self, size):
- super().__init__()
- _log_api_usage_once(self)
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- Returns:
- PIL Image or Tensor: Cropped image.
- """
- return F.center_crop(img, self.size)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size})"
- class Pad(torch.nn.Module):
- """Pad the given image on all sides with the given "pad" value.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
- at most 3 leading dimensions for mode edge,
- and an arbitrary number of leading dimensions for mode constant
- Args:
- padding (int or sequence): Padding on each border. If a single int is provided this
- is used to pad all borders. If sequence of length 2 is provided this is the padding
- on left/right and top/bottom respectively. If a sequence of length 4 is provided
- this is the padding for the left, top, right and bottom borders respectively.
- .. note::
- In torchscript mode padding as single int is not supported, use a sequence of
- length 1: ``[padding, ]``.
- fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
- length 3, it is used to fill R, G, B channels respectively.
- This value is only used when the padding_mode is constant.
- Only number is supported for torch Tensor.
- Only int or tuple value is supported for PIL Image.
- padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
- Default is constant.
- - constant: pads with a constant value, this value is specified with fill
- - edge: pads with the last value at the edge of the image.
- If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
- - reflect: pads with reflection of image without repeating the last value on the edge.
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
- will result in [3, 2, 1, 2, 3, 4, 3, 2]
- - symmetric: pads with reflection of image repeating the last value on the edge.
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
- will result in [2, 1, 1, 2, 3, 4, 4, 3]
- """
- def __init__(self, padding, fill=0, padding_mode="constant"):
- super().__init__()
- _log_api_usage_once(self)
- if not isinstance(padding, (numbers.Number, tuple, list)):
- raise TypeError("Got inappropriate padding arg")
- if not isinstance(fill, (numbers.Number, tuple, list)):
- raise TypeError("Got inappropriate fill arg")
- if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
- raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
- if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
- raise ValueError(
- f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
- )
- self.padding = padding
- self.fill = fill
- self.padding_mode = padding_mode
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be padded.
- Returns:
- PIL Image or Tensor: Padded image.
- """
- return F.pad(img, self.padding, self.fill, self.padding_mode)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
- class Lambda:
- """Apply a user-defined lambda as a transform. This transform does not support torchscript.
- Args:
- lambd (function): Lambda/function to be used for transform.
- """
- def __init__(self, lambd):
- _log_api_usage_once(self)
- if not callable(lambd):
- raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
- self.lambd = lambd
- def __call__(self, img):
- return self.lambd(img)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}()"
- class RandomTransforms:
- """Base class for a list of transformations with randomness
- Args:
- transforms (sequence): list of transformations
- """
- def __init__(self, transforms):
- _log_api_usage_once(self)
- if not isinstance(transforms, Sequence):
- raise TypeError("Argument transforms should be a sequence")
- self.transforms = transforms
- def __call__(self, *args, **kwargs):
- raise NotImplementedError()
- def __repr__(self) -> str:
- format_string = self.__class__.__name__ + "("
- for t in self.transforms:
- format_string += "\n"
- format_string += f" {t}"
- format_string += "\n)"
- return format_string
- class RandomApply(torch.nn.Module):
- """Apply randomly a list of transformations with a given probability.
- .. note::
- In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
- transforms as shown below:
- >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
- >>> transforms.ColorJitter(),
- >>> ]), p=0.3)
- >>> scripted_transforms = torch.jit.script(transforms)
- Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
- require ``lambda`` functions or ``PIL.Image``.
- Args:
- transforms (sequence or torch.nn.Module): list of transformations
- p (float): probability
- """
- def __init__(self, transforms, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.transforms = transforms
- self.p = p
- def forward(self, img):
- if self.p < torch.rand(1):
- return img
- for t in self.transforms:
- img = t(img)
- return img
- def __repr__(self) -> str:
- format_string = self.__class__.__name__ + "("
- format_string += f"\n p={self.p}"
- for t in self.transforms:
- format_string += "\n"
- format_string += f" {t}"
- format_string += "\n)"
- return format_string
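- # Usage sketch (illustrative), following the scripting note above: wrap the transforms in a
- # torch.nn.ModuleList so the whole RandomApply module can be scripted.
- #   t = RandomApply(torch.nn.ModuleList([ColorJitter(brightness=0.2)]), p=0.3)
- #   scripted = torch.jit.script(t)
- #   out = scripted(torch.rand(3, 32, 32))   # ColorJitter is applied with probability 0.3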
- class RandomOrder(RandomTransforms):
- """Apply a list of transformations in a random order. This transform does not support torchscript."""
- def __call__(self, img):
- order = list(range(len(self.transforms)))
- random.shuffle(order)
- for i in order:
- img = self.transforms[i](img)
- return img
- class RandomChoice(RandomTransforms):
- """Apply single transformation randomly picked from a list. This transform does not support torchscript."""
- def __init__(self, transforms, p=None):
- super().__init__(transforms)
- if p is not None and not isinstance(p, Sequence):
- raise TypeError("Argument p should be a sequence")
- self.p = p
- def __call__(self, *args):
- t = random.choices(self.transforms, weights=self.p)[0]
- return t(*args)
- def __repr__(self) -> str:
- return f"{super().__repr__()}(p={self.p})"
- class RandomCrop(torch.nn.Module):
- """Crop the given image at a random location.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
- but if non-constant padding is used, the input is expected to have at most 2 leading dimensions
- Args:
- size (sequence or int): Desired output size of the crop. If size is an
- int instead of sequence like (h, w), a square crop (size, size) is
- made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- padding (int or sequence, optional): Optional padding on each border
- of the image. Default is None. If a single int is provided this
- is used to pad all borders. If sequence of length 2 is provided this is the padding
- on left/right and top/bottom respectively. If a sequence of length 4 is provided
- this is the padding for the left, top, right and bottom borders respectively.
- .. note::
- In torchscript mode padding as single int is not supported, use a sequence of
- length 1: ``[padding, ]``.
- pad_if_needed (boolean): It will pad the image if smaller than the
- desired size to avoid raising an exception. Since cropping is done
- after padding, the padding seems to be done at a random offset.
- fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
- length 3, it is used to fill R, G, B channels respectively.
- This value is only used when the padding_mode is constant.
- Only number is supported for torch Tensor.
- Only int or tuple value is supported for PIL Image.
- padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
- Default is constant.
- - constant: pads with a constant value, this value is specified with fill
- - edge: pads with the last value at the edge of the image.
- If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
- - reflect: pads with reflection of image without repeating the last value on the edge.
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
- will result in [3, 2, 1, 2, 3, 4, 3, 2]
- - symmetric: pads with reflection of image repeating the last value on the edge.
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
- will result in [2, 1, 1, 2, 3, 4, 4, 3]
- """
- @staticmethod
- def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
- """Get parameters for ``crop`` for a random crop.
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- output_size (tuple): Expected output size of the crop.
- Returns:
- tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
- """
- _, h, w = F.get_dimensions(img)
- th, tw = output_size
- if h < th or w < tw:
- raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
- if w == tw and h == th:
- return 0, 0, h, w
- i = torch.randint(0, h - th + 1, size=(1,)).item()
- j = torch.randint(0, w - tw + 1, size=(1,)).item()
- return i, j, th, tw
- def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
- super().__init__()
- _log_api_usage_once(self)
- self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
- self.padding = padding
- self.pad_if_needed = pad_if_needed
- self.fill = fill
- self.padding_mode = padding_mode
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- Returns:
- PIL Image or Tensor: Cropped image.
- """
- if self.padding is not None:
- img = F.pad(img, self.padding, self.fill, self.padding_mode)
- _, height, width = F.get_dimensions(img)
- # pad the width if needed
- if self.pad_if_needed and width < self.size[1]:
- padding = [self.size[1] - width, 0]
- img = F.pad(img, padding, self.fill, self.padding_mode)
- # pad the height if needed
- if self.pad_if_needed and height < self.size[0]:
- padding = [0, self.size[0] - height]
- img = F.pad(img, padding, self.fill, self.padding_mode)
- i, j, h, w = self.get_params(img, self.size)
- return F.crop(img, i, j, h, w)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"
- class RandomHorizontalFlip(torch.nn.Module):
- """Horizontally flip the given image randomly with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading
- dimensions
- Args:
- p (float): probability of the image being flipped. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be flipped.
- Returns:
- PIL Image or Tensor: Randomly flipped image.
- """
- if torch.rand(1) < self.p:
- return F.hflip(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomVerticalFlip(torch.nn.Module):
- """Vertically flip the given image randomly with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading
- dimensions
- Args:
- p (float): probability of the image being flipped. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be flipped.
- Returns:
- PIL Image or Tensor: Randomly flipped image.
- """
- if torch.rand(1) < self.p:
- return F.vflip(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomPerspective(torch.nn.Module):
- """Performs a random perspective transformation of the given image with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
- Default is 0.5.
- p (float): probability of the image being transformed. Default is 0.5.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- fill (sequence or number): Pixel fill value for the area outside the transformed
- image. Default is ``0``. If given a number, the value is used for all bands.
- """
- def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.interpolation = interpolation
- self.distortion_scale = distortion_scale
- if fill is None:
- fill = 0
- elif not isinstance(fill, (Sequence, numbers.Number)):
- raise TypeError("Fill should be either a sequence or a number.")
- self.fill = fill
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be perspectively transformed.
- Returns:
- PIL Image or Tensor: Randomly transformed image.
- """
- fill = self.fill
- channels, height, width = F.get_dimensions(img)
- if isinstance(img, Tensor):
- if isinstance(fill, (int, float)):
- fill = [float(fill)] * channels
- else:
- fill = [float(f) for f in fill]
- if torch.rand(1) < self.p:
- startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
- return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
- return img
- @staticmethod
- def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
- """Get parameters for ``perspective`` for a random perspective transform.
- Args:
- width (int): width of the image.
- height (int): height of the image.
- distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
- Returns:
- List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
- List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
- """
- half_height = height // 2
- half_width = width // 2
- topleft = [
- int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
- int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
- ]
- topright = [
- int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
- int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
- ]
- botright = [
- int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
- int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
- ]
- botleft = [
- int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
- int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
- ]
- startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
- endpoints = [topleft, topright, botright, botleft]
- return startpoints, endpoints
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomResizedCrop(torch.nn.Module):
- """Crop a random portion of image and resize it to a given size.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
- A crop of the original image is made: the crop has a random area (H * W)
- and a random aspect ratio. This crop is finally resized to the given
- size. This is popularly used to train the Inception networks.
- Args:
- size (int or sequence): expected output size of the crop, for each edge. If size is an
- int instead of sequence like (h, w), a square output size ``(size, size)`` is
- made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- .. note::
- In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
- scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
- before resizing. The scale is defined with respect to the area of the original image.
- ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
- resizing.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
- ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- antialias (bool, optional): Whether to apply antialiasing.
- It only affects **tensors** with bilinear or bicubic modes and it is
- ignored otherwise: on PIL images, antialiasing is always applied on
- bilinear or bicubic modes; on other modes (for PIL images and
- tensors), antialiasing makes no sense and this parameter is ignored.
- Possible values are:
- - ``True``: will apply antialiasing for bilinear or bicubic modes.
- Other modes aren't affected. This is probably what you want to use.
- - ``False``: will not apply antialiasing for tensors on any mode. PIL
- images are still antialiased on bilinear or bicubic modes, because
- PIL does not support disabling antialiasing.
- - ``None``: equivalent to ``False`` for tensors and ``True`` for
- PIL images. This value exists for legacy reasons and you probably
- don't want to use it unless you really know what you are doing.
- The current default is ``None`` **but will change to** ``True`` **in
- v0.17** for the PIL and Tensor backends to be consistent.
- """
- def __init__(
- self,
- size,
- scale=(0.08, 1.0),
- ratio=(3.0 / 4.0, 4.0 / 3.0),
- interpolation=InterpolationMode.BILINEAR,
- antialias: Optional[Union[str, bool]] = "warn",
- ):
- super().__init__()
- _log_api_usage_once(self)
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
- if not isinstance(scale, Sequence):
- raise TypeError("Scale should be a sequence")
- if not isinstance(ratio, Sequence):
- raise TypeError("Ratio should be a sequence")
- if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
- warnings.warn("Scale and ratio should be of kind (min, max)")
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.interpolation = interpolation
- self.antialias = antialias
- self.scale = scale
- self.ratio = ratio
- @staticmethod
- def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
- """Get parameters for ``crop`` for a random sized crop.
- Args:
- img (PIL Image or Tensor): Input image.
- scale (list): range of the fraction of the original image area covered by the crop
- ratio (list): range of the aspect ratio of the crop
- Returns:
- tuple: params (i, j, h, w) to be passed to ``crop`` for a random
- sized crop.
- """
- _, height, width = F.get_dimensions(img)
- area = height * width
- log_ratio = torch.log(torch.tensor(ratio))
- for _ in range(10):
- target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
- aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
- w = int(round(math.sqrt(target_area * aspect_ratio)))
- h = int(round(math.sqrt(target_area / aspect_ratio)))
- if 0 < w <= width and 0 < h <= height:
- i = torch.randint(0, height - h + 1, size=(1,)).item()
- j = torch.randint(0, width - w + 1, size=(1,)).item()
- return i, j, h, w
- # Fallback to central crop
- in_ratio = float(width) / float(height)
- if in_ratio < min(ratio):
- w = width
- h = int(round(w / min(ratio)))
- elif in_ratio > max(ratio):
- h = height
- w = int(round(h * max(ratio)))
- else: # whole image
- w = width
- h = height
- i = (height - h) // 2
- j = (width - w) // 2
- return i, j, h, w
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped and resized.
- Returns:
- PIL Image or Tensor: Randomly cropped and resized image.
- """
- i, j, h, w = self.get_params(img, self.scale, self.ratio)
- return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)
- def __repr__(self) -> str:
- interpolate_str = self.interpolation.value
- format_string = self.__class__.__name__ + f"(size={self.size}"
- format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
- format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
- format_string += f", interpolation={interpolate_str}"
- format_string += f", antialias={self.antialias})"
- return format_string
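- # Usage sketch (illustrative): the typical ImageNet-style training crop with the default scale and
- # ratio ranges, resized to 224 x 224.
- #   rrc = RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), antialias=True)
- #   out = rrc(torch.rand(3, 500, 375))   # out.shape == torch.Size([3, 224, 224])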
- class FiveCrop(torch.nn.Module):
- """Crop the given image into four corners and the central crop.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading
- dimensions
- .. Note::
- This transform returns a tuple of images and there may be a mismatch in the number of
- inputs and targets your Dataset returns. See below for an example of how to deal with
- this.
- Args:
- size (sequence or int): Desired output size of the crop. If size is an ``int``
- instead of sequence like (h, w), a square crop of size (size, size) is made.
- If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- Example:
- >>> transform = Compose([
- >>> FiveCrop(size), # this is a tuple of PIL Images
- >>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
- >>> ])
- >>> #In your test loop you can do the following:
- >>> input, target = batch # input is a 5d tensor, target is 2d
- >>> bs, ncrops, c, h, w = input.size()
- >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
- >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
- """
- def __init__(self, size):
- super().__init__()
- _log_api_usage_once(self)
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- Returns:
- tuple of 5 images. Image can be PIL Image or Tensor
- """
- return F.five_crop(img, self.size)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size})"
- class TenCrop(torch.nn.Module):
- """Crop the given image into four corners and the central crop plus the flipped version of
- these (horizontal flipping is used by default).
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading
- dimensions
- .. Note::
- This transform returns a tuple of images and there may be a mismatch in the number of
- inputs and targets your Dataset returns. See below for an example of how to deal with
- this.
- Args:
- size (sequence or int): Desired output size of the crop. If size is an
- int instead of sequence like (h, w), a square crop (size, size) is
- made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- vertical_flip (bool): Use vertical flipping instead of horizontal
- Example:
- >>> transform = Compose([
- >>> TenCrop(size), # this is a tuple of PIL Images
- >>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
- >>> ])
- >>> #In your test loop you can do the following:
- >>> input, target = batch # input is a 5d tensor, target is 2d
- >>> bs, ncrops, c, h, w = input.size()
- >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
- >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
- """
- def __init__(self, size, vertical_flip=False):
- super().__init__()
- _log_api_usage_once(self)
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
- self.vertical_flip = vertical_flip
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- Returns:
- tuple of 10 images. Image can be PIL Image or Tensor
- """
- return F.ten_crop(img, self.size, self.vertical_flip)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size}, vertical_flip={self.vertical_flip})"
- class LinearTransformation(torch.nn.Module):
- """Transform a tensor image with a square transformation matrix and a mean_vector computed
- offline.
- This transform does not support PIL Image.
- Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
- subtract mean_vector from it which is then followed by computing the dot
- product with the transformation matrix and then reshaping the tensor to its
- original shape.
- Applications:
- whitening transformation: Suppose X is a matrix whose rows are zero-centered, flattened data samples.
- Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
- perform SVD on this matrix and pass the resulting transformation as transformation_matrix.
- Args:
- transformation_matrix (Tensor): tensor [D x D], D = C x H x W
- mean_vector (Tensor): tensor [D], D = C x H x W
- """
- def __init__(self, transformation_matrix, mean_vector):
- super().__init__()
- _log_api_usage_once(self)
- if transformation_matrix.size(0) != transformation_matrix.size(1):
- raise ValueError(
- "transformation_matrix should be square. Got "
- f"{tuple(transformation_matrix.size())} rectangular matrix."
- )
- if mean_vector.size(0) != transformation_matrix.size(0):
- raise ValueError(
- f"mean_vector should have the same length {mean_vector.size(0)}"
- f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
- )
- if transformation_matrix.device != mean_vector.device:
- raise ValueError(
- f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
- )
- if transformation_matrix.dtype != mean_vector.dtype:
- raise ValueError(
- f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
- )
- self.transformation_matrix = transformation_matrix
- self.mean_vector = mean_vector
- def forward(self, tensor: Tensor) -> Tensor:
- """
- Args:
- tensor (Tensor): Tensor image to be whitened.
- Returns:
- Tensor: Transformed image.
- """
- shape = tensor.shape
- n = shape[-3] * shape[-2] * shape[-1]
- if n != self.transformation_matrix.shape[0]:
- raise ValueError(
- "Input tensor and transformation matrix have incompatible shape."
- + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
- + f"{self.transformation_matrix.shape[0]}"
- )
- if tensor.device.type != self.mean_vector.device.type:
- raise ValueError(
- "Input tensor should be on the same device as transformation matrix and mean vector. "
- f"Got {tensor.device} vs {self.mean_vector.device}"
- )
- flat_tensor = tensor.view(-1, n) - self.mean_vector
- transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
- transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
- tensor = transformed_tensor.view(shape)
- return tensor
- def __repr__(self) -> str:
- s = (
- f"{self.__class__.__name__}(transformation_matrix="
- f"{self.transformation_matrix.tolist()}"
- f", mean_vector={self.mean_vector.tolist()})"
- )
- return s
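- # Sketch (illustrative, with assumptions): building a ZCA-style whitening matrix for flattened images of
- # size D = C * H * W, following the "Applications" note above. `flat_images` is assumed to be an (N, D)
- # tensor of flattened training images; the 1 / N factor and the 1e-5 epsilon are stability choices made
- # here, not prescribed by this module.
- #   X = flat_images - flat_images.mean(dim=0)
- #   cov = torch.mm(X.t(), X) / X.size(0)
- #   U, S, _ = torch.svd(cov)
- #   W = torch.mm(torch.mm(U, torch.diag((S + 1e-5).rsqrt())), U.t())
- #   whiten = LinearTransformation(W, flat_images.mean(dim=0))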
- class ColorJitter(torch.nn.Module):
- """Randomly change the brightness, contrast, saturation and hue of an image.
- If the image is torch Tensor, it is expected
- to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
- Args:
- brightness (float or tuple of float (min, max)): How much to jitter brightness.
- brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
- or the given [min, max]. Should be non-negative numbers.
- contrast (float or tuple of float (min, max)): How much to jitter contrast.
- contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
- or the given [min, max]. Should be non-negative numbers.
- saturation (float or tuple of float (min, max)): How much to jitter saturation.
- saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
- or the given [min, max]. Should be non-negative numbers.
- hue (float or tuple of float (min, max)): How much to jitter hue.
- hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
- Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
- To jitter hue, the pixel values of the input image have to be non-negative for conversion to HSV space;
- thus it does not work if you normalize your image to an interval with negative values,
- or use an interpolation that generates negative values before using this function.
- """
- def __init__(
- self,
- brightness: Union[float, Tuple[float, float]] = 0,
- contrast: Union[float, Tuple[float, float]] = 0,
- saturation: Union[float, Tuple[float, float]] = 0,
- hue: Union[float, Tuple[float, float]] = 0,
- ) -> None:
- super().__init__()
- _log_api_usage_once(self)
- self.brightness = self._check_input(brightness, "brightness")
- self.contrast = self._check_input(contrast, "contrast")
- self.saturation = self._check_input(saturation, "saturation")
- self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
- @torch.jit.unused
- def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
- if isinstance(value, numbers.Number):
- if value < 0:
- raise ValueError(f"If {name} is a single number, it must be non negative.")
- value = [center - float(value), center + float(value)]
- if clip_first_on_zero:
- value[0] = max(value[0], 0.0)
- elif isinstance(value, (tuple, list)) and len(value) == 2:
- value = [float(value[0]), float(value[1])]
- else:
- raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
- if not bound[0] <= value[0] <= value[1] <= bound[1]:
- raise ValueError(f"{name} values should be between {bound}, but got {value}.")
- # if value is 0 or (1., 1.) for brightness/contrast/saturation
- # or (0., 0.) for hue, do nothing
- if value[0] == value[1] == center:
- return None
- else:
- return tuple(value)
- @staticmethod
- def get_params(
- brightness: Optional[List[float]],
- contrast: Optional[List[float]],
- saturation: Optional[List[float]],
- hue: Optional[List[float]],
- ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
- """Get the parameters for the randomized transform to be applied on image.
- Args:
- brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
- uniformly. Pass None to turn off the transformation.
- contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
- uniformly. Pass None to turn off the transformation.
- saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
- uniformly. Pass None to turn off the transformation.
- hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
- Pass None to turn off the transformation.
- Returns:
- tuple: The parameters used to apply the randomized transform
- along with their random order.
- """
- fn_idx = torch.randperm(4)
- b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
- c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
- s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
- h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
- return fn_idx, b, c, s, h
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Input image.
- Returns:
- PIL Image or Tensor: Color jittered image.
- """
- fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
- self.brightness, self.contrast, self.saturation, self.hue
- )
- for fn_id in fn_idx:
- if fn_id == 0 and brightness_factor is not None:
- img = F.adjust_brightness(img, brightness_factor)
- elif fn_id == 1 and contrast_factor is not None:
- img = F.adjust_contrast(img, contrast_factor)
- elif fn_id == 2 and saturation_factor is not None:
- img = F.adjust_saturation(img, saturation_factor)
- elif fn_id == 3 and hue_factor is not None:
- img = F.adjust_hue(img, hue_factor)
- return img
- def __repr__(self) -> str:
- s = (
- f"{self.__class__.__name__}("
- f"brightness={self.brightness}"
- f", contrast={self.contrast}"
- f", saturation={self.saturation}"
- f", hue={self.hue})"
- )
- return s
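- # Usage sketch (illustrative): with brightness=contrast=saturation=0.4 and hue=0.1, each call draws
- # brightness/contrast/saturation factors from [0.6, 1.4] and a hue factor from [-0.1, 0.1].
- #   jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
- #   out = jitter(torch.rand(3, 64, 64))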
- class RandomRotation(torch.nn.Module):
- """Rotate the image by angle.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- degrees (sequence or number): Range of degrees to select from.
- If degrees is a number instead of sequence like (min, max), the range of degrees
- will be (-degrees, +degrees).
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- expand (bool, optional): Optional expansion flag.
- If true, expands the output to make it large enough to hold the entire rotated image.
- If false or omitted, make the output image the same size as the input image.
- Note that the expand flag assumes rotation around the center and no translation.
- center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
- Default is the center of the image.
- fill (sequence or number): Pixel fill value for the area outside the rotated
- image. Default is ``0``. If given a number, the value is used for all bands.
- .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
- """
- def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
- super().__init__()
- _log_api_usage_once(self)
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
- if center is not None:
- _check_sequence_input(center, "center", req_sizes=(2,))
- self.center = center
- self.interpolation = interpolation
- self.expand = expand
- if fill is None:
- fill = 0
- elif not isinstance(fill, (Sequence, numbers.Number)):
- raise TypeError("Fill should be either a sequence or a number.")
- self.fill = fill
- @staticmethod
- def get_params(degrees: List[float]) -> float:
- """Get parameters for ``rotate`` for a random rotation.
- Returns:
- float: angle parameter to be passed to ``rotate`` for random rotation.
- """
- angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
- return angle
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be rotated.
- Returns:
- PIL Image or Tensor: Rotated image.
- """
- fill = self.fill
- channels, _, _ = F.get_dimensions(img)
- if isinstance(img, Tensor):
- if isinstance(fill, (int, float)):
- fill = [float(fill)] * channels
- else:
- fill = [float(f) for f in fill]
- angle = self.get_params(self.degrees)
- return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)
- def __repr__(self) -> str:
- interpolate_str = self.interpolation.value
- format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
- format_string += f", interpolation={interpolate_str}"
- format_string += f", expand={self.expand}"
- if self.center is not None:
- format_string += f", center={self.center}"
- if self.fill is not None:
- format_string += f", fill={self.fill}"
- format_string += ")"
- return format_string
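- # Usage sketch (illustrative): rotate by an angle drawn uniformly from [-30, 30] degrees, enlarging the
- # output so the whole rotated image fits and filling the exposed corners with 0.5.
- #   rot = RandomRotation(30, interpolation=InterpolationMode.BILINEAR, expand=True, fill=0.5)
- #   out = rot(torch.rand(3, 64, 64))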
- class RandomAffine(torch.nn.Module):
- """Random affine transformation of the image keeping center invariant.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- degrees (sequence or number): Range of degrees to select from.
- If degrees is a number instead of sequence like (min, max), the range of degrees
- will be (-degrees, +degrees). Set to 0 to deactivate rotations.
- translate (tuple, optional): tuple of maximum absolute fractions for horizontal
- and vertical translations. For example, if translate=(a, b), then the horizontal shift
- is randomly sampled in the range -img_width * a < dx < img_width * a and the vertical shift is
- randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
- scale (tuple, optional): scaling factor interval, e.g. (a, b), then the scale is
- randomly sampled from the range a <= scale <= b. Will keep original scale by default.
- shear (sequence or number, optional): Range of degrees to select from.
- If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
- will be applied. Else if shear is a sequence of 2 values, a shear parallel to the x-axis in the
- range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
- an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
- Will not apply shear by default.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- fill (sequence or number): Pixel fill value for the area outside the transformed
- image. Default is ``0``. If given a number, the value is used for all bands.
- center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
- Default is the center of the image.
- .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
- """
- def __init__(
- self,
- degrees,
- translate=None,
- scale=None,
- shear=None,
- interpolation=InterpolationMode.NEAREST,
- fill=0,
- center=None,
- ):
- super().__init__()
- _log_api_usage_once(self)
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
- if translate is not None:
- _check_sequence_input(translate, "translate", req_sizes=(2,))
- for t in translate:
- if not (0.0 <= t <= 1.0):
- raise ValueError("translation values should be between 0 and 1")
- self.translate = translate
- if scale is not None:
- _check_sequence_input(scale, "scale", req_sizes=(2,))
- for s in scale:
- if s <= 0:
- raise ValueError("scale values should be positive")
- self.scale = scale
- if shear is not None:
- self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
- else:
- self.shear = shear
- self.interpolation = interpolation
- if fill is None:
- fill = 0
- elif not isinstance(fill, (Sequence, numbers.Number)):
- raise TypeError("Fill should be either a sequence or a number.")
- self.fill = fill
- if center is not None:
- _check_sequence_input(center, "center", req_sizes=(2,))
- self.center = center
- @staticmethod
- def get_params(
- degrees: List[float],
- translate: Optional[List[float]],
- scale_ranges: Optional[List[float]],
- shears: Optional[List[float]],
- img_size: List[int],
- ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
- """Get parameters for affine transformation
- Returns:
- params to be passed to the affine transformation
- """
- angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
- if translate is not None:
- max_dx = float(translate[0] * img_size[0])
- max_dy = float(translate[1] * img_size[1])
- tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
- ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
- translations = (tx, ty)
- else:
- translations = (0, 0)
- if scale_ranges is not None:
- scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
- else:
- scale = 1.0
- shear_x = shear_y = 0.0
- if shears is not None:
- shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
- if len(shears) == 4:
- shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
- shear = (shear_x, shear_y)
- return angle, translations, scale, shear
- def forward(self, img):
- """
- img (PIL Image or Tensor): Image to be transformed.
- Returns:
- PIL Image or Tensor: Affine transformed image.
- """
- fill = self.fill
- channels, height, width = F.get_dimensions(img)
- if isinstance(img, Tensor):
- if isinstance(fill, (int, float)):
- fill = [float(fill)] * channels
- else:
- fill = [float(f) for f in fill]
- img_size = [width, height] # flip for keeping BC on get_params call
- ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
- return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
- def __repr__(self) -> str:
- s = f"{self.__class__.__name__}(degrees={self.degrees}"
- s += f", translate={self.translate}" if self.translate is not None else ""
- s += f", scale={self.scale}" if self.scale is not None else ""
- s += f", shear={self.shear}" if self.shear is not None else ""
- s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
- s += f", fill={self.fill}" if self.fill != 0 else ""
- s += f", center={self.center}" if self.center is not None else ""
- s += ")"
- return s
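- # Illustrative usage sketch (not part of the original module): one RandomAffine draw
- # combining rotation, translation, scale and shear; every argument value below is an
- # assumption chosen for the example.
- def _example_random_affine():
-     img = torch.randint(0, 256, (3, 128, 128), dtype=torch.uint8)  # hypothetical input
-     affine = RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5)
-     return affine(img)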
- class Grayscale(torch.nn.Module):
- """Convert image to grayscale.
- If the image is torch Tensor, it is expected
- to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
- Args:
- num_output_channels (int): number of channels desired for the output image (1 or 3)
- Returns:
- PIL Image: Grayscale version of the input.
- - If ``num_output_channels == 1`` : returned image is single channel
- - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
- """
- def __init__(self, num_output_channels=1):
- super().__init__()
- _log_api_usage_once(self)
- self.num_output_channels = num_output_channels
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be converted to grayscale.
- Returns:
- PIL Image or Tensor: Grayscaled image.
- """
- return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"
- class RandomGrayscale(torch.nn.Module):
- """Randomly convert image to grayscale with a probability of p (default 0.1).
- If the image is torch Tensor, it is expected
- to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
- Args:
- p (float): probability that image should be converted to grayscale.
- Returns:
- PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
- with probability (1-p).
- - If input image is 1 channel: grayscale version is 1 channel
- - If input image is 3 channel: grayscale version is 3 channel with r == g == b
- """
- def __init__(self, p=0.1):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be converted to grayscale.
- Returns:
- PIL Image or Tensor: Randomly grayscaled image.
- """
- num_output_channels, _, _ = F.get_dimensions(img)
- if torch.rand(1) < self.p:
- return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomErasing(torch.nn.Module):
- """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
- This transform does not support PIL Image.
- 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
- Args:
- p: probability that the random erasing operation will be performed.
- scale: range of proportion of erased area against input image.
- ratio: range of aspect ratio of erased area.
- value: erasing value. Default is 0. If a single int, it is used to
- erase all pixels. If a tuple of length 3, it is used to erase
- R, G, B channels respectively.
- If the string 'random' is given, each pixel is erased with a random value.
- inplace: if True, perform the erasing in place. Default is False.
- Returns:
- Erased Image.
- Example:
- >>> transform = transforms.Compose([
- >>> transforms.RandomHorizontalFlip(),
- >>> transforms.PILToTensor(),
- >>> transforms.ConvertImageDtype(torch.float),
- >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- >>> transforms.RandomErasing(),
- >>> ])
- """
- def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
- super().__init__()
- _log_api_usage_once(self)
- if not isinstance(value, (numbers.Number, str, tuple, list)):
- raise TypeError("Argument value should be either a number or str or a sequence")
- if isinstance(value, str) and value != "random":
- raise ValueError("If value is str, it should be 'random'")
- if not isinstance(scale, (tuple, list)):
- raise TypeError("Scale should be a sequence")
- if not isinstance(ratio, (tuple, list)):
- raise TypeError("Ratio should be a sequence")
- if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
- warnings.warn("Scale and ratio should be of kind (min, max)")
- if scale[0] < 0 or scale[1] > 1:
- raise ValueError("Scale should be between 0 and 1")
- if p < 0 or p > 1:
- raise ValueError("Random erasing probability should be between 0 and 1")
- self.p = p
- self.scale = scale
- self.ratio = ratio
- self.value = value
- self.inplace = inplace
- @staticmethod
- def get_params(
- img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
- ) -> Tuple[int, int, int, int, Tensor]:
- """Get parameters for ``erase`` for a random erasing.
- Args:
- img (Tensor): Tensor image to be erased.
- scale (sequence): range of proportion of erased area against input image.
- ratio (sequence): range of aspect ratio of erased area.
- value (list, optional): erasing value. If None, it is interpreted as "random"
- (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
- i.e. ``value[0]``.
- Returns:
- tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
- """
- img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
- area = img_h * img_w
- log_ratio = torch.log(torch.tensor(ratio))
- for _ in range(10):
- erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
- aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
- h = int(round(math.sqrt(erase_area * aspect_ratio)))
- w = int(round(math.sqrt(erase_area / aspect_ratio)))
- if not (h < img_h and w < img_w):
- continue
- if value is None:
- v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
- else:
- v = torch.tensor(value)[:, None, None]
- i = torch.randint(0, img_h - h + 1, size=(1,)).item()
- j = torch.randint(0, img_w - w + 1, size=(1,)).item()
- return i, j, h, w, v
- # Return original image
- return 0, 0, img_h, img_w, img
- def forward(self, img):
- """
- Args:
- img (Tensor): Tensor image to be erased.
- Returns:
- img (Tensor): Erased Tensor image.
- """
- if torch.rand(1) < self.p:
- # cast self.value to script acceptable type
- if isinstance(self.value, (int, float)):
- value = [float(self.value)]
- elif isinstance(self.value, str):
- value = None
- elif isinstance(self.value, (list, tuple)):
- value = [float(v) for v in self.value]
- else:
- value = self.value
- if value is not None and not (len(value) in (1, img.shape[-3])):
- raise ValueError(
- "If value is a sequence, it should have either a single value or "
- f"{img.shape[-3]} (number of input channels)"
- )
- x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
- return F.erase(img, x, y, h, w, v, self.inplace)
- return img
- def __repr__(self) -> str:
- s = (
- f"{self.__class__.__name__}"
- f"(p={self.p}, "
- f"scale={self.scale}, "
- f"ratio={self.ratio}, "
- f"value={self.value}, "
- f"inplace={self.inplace})"
- )
- return s
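- # Illustrative usage sketch (not part of the original module): erases one random rectangle
- # with per-pixel random values; the probability, scale range and input shape are assumptions
- # for the example. Note that this transform only accepts tensor inputs, not PIL Images.
- def _example_random_erasing():
-     img = torch.rand(3, 96, 96)  # hypothetical float image in [0, 1]
-     eraser = RandomErasing(p=1.0, scale=(0.02, 0.2), value="random")
-     return eraser(img)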
- class GaussianBlur(torch.nn.Module):
- """Blurs image with randomly chosen Gaussian blur.
- If the image is torch Tensor, it is expected
- to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- kernel_size (int or sequence): Size of the Gaussian kernel.
- sigma (float or tuple of float (min, max)): Standard deviation to be used for
- creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
- of float (min, max), sigma is chosen uniformly at random to lie in the
- given range.
- Returns:
- PIL Image or Tensor: Gaussian blurred version of the input image.
- """
- def __init__(self, kernel_size, sigma=(0.1, 2.0)):
- super().__init__()
- _log_api_usage_once(self)
- self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
- for ks in self.kernel_size:
- if ks <= 0 or ks % 2 == 0:
- raise ValueError("Kernel size value should be an odd and positive number.")
- if isinstance(sigma, numbers.Number):
- if sigma <= 0:
- raise ValueError("If sigma is a single number, it must be positive.")
- sigma = (sigma, sigma)
- elif isinstance(sigma, Sequence) and len(sigma) == 2:
- if not 0.0 < sigma[0] <= sigma[1]:
- raise ValueError("sigma values should be positive and of the form (min, max).")
- else:
- raise ValueError("sigma should be a single number or a list/tuple with length 2.")
- self.sigma = sigma
- @staticmethod
- def get_params(sigma_min: float, sigma_max: float) -> float:
- """Choose sigma for random gaussian blurring.
- Args:
- sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
- sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
- Returns:
- float: Standard deviation to be passed to calculate kernel for gaussian blurring.
- """
- return torch.empty(1).uniform_(sigma_min, sigma_max).item()
- def forward(self, img: Tensor) -> Tensor:
- """
- Args:
- img (PIL Image or Tensor): image to be blurred.
- Returns:
- PIL Image or Tensor: Gaussian blurred image
- """
- sigma = self.get_params(self.sigma[0], self.sigma[1])
- return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
- def __repr__(self) -> str:
- s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
- return s
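- # Illustrative usage sketch (not part of the original module): a 5x5 Gaussian blur with
- # sigma drawn uniformly from (0.1, 2.0); the kernel size, sigma range and input shape are
- # assumptions for the example.
- def _example_gaussian_blur():
-     img = torch.rand(3, 64, 64)
-     blur = GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
-     return blur(img)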
- def _setup_size(size, error_msg):
- if isinstance(size, numbers.Number):
- return int(size), int(size)
- if isinstance(size, Sequence) and len(size) == 1:
- return size[0], size[0]
- if len(size) != 2:
- raise ValueError(error_msg)
- return size
- def _check_sequence_input(x, name, req_sizes):
- msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
- if not isinstance(x, Sequence):
- raise TypeError(f"{name} should be a sequence of length {msg}.")
- if len(x) not in req_sizes:
- raise ValueError(f"{name} should be a sequence of length {msg}.")
- def _setup_angle(x, name, req_sizes=(2,)):
- if isinstance(x, numbers.Number):
- if x < 0:
- raise ValueError(f"If {name} is a single number, it must be positive.")
- x = [-x, x]
- else:
- _check_sequence_input(x, name, req_sizes)
- return [float(d) for d in x]
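- # Illustrative sketch (not part of the original module): how the helpers above normalize
- # their inputs; the concrete values are assumptions used only to show the expected outputs.
- def _example_helpers():
-     assert _setup_size(224, "size should be a pair of ints") == (224, 224)
-     assert _setup_angle(10, name="degrees") == [-10.0, 10.0]
-     assert _setup_angle((5, 15), name="degrees") == [5.0, 15.0]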
- class RandomInvert(torch.nn.Module):
- """Inverts the colors of the given image randomly with a given probability.
- If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
- where ... means it can have an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "L" or "RGB".
- Args:
- p (float): probability of the image being color inverted. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be inverted.
- Returns:
- PIL Image or Tensor: Randomly color inverted image.
- """
- if torch.rand(1).item() < self.p:
- return F.invert(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomPosterize(torch.nn.Module):
- """Posterize the image randomly with a given probability by reducing the
- number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
- and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "L" or "RGB".
- Args:
- bits (int): number of bits to keep for each channel (0-8)
- p (float): probability of the image being posterized. Default value is 0.5
- """
- def __init__(self, bits, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.bits = bits
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be posterized.
- Returns:
- PIL Image or Tensor: Randomly posterized image.
- """
- if torch.rand(1).item() < self.p:
- return F.posterize(img, self.bits)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"
- class RandomSolarize(torch.nn.Module):
- """Solarize the image randomly with a given probability by inverting all pixel
- values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
- where ... means it can have an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "L" or "RGB".
- Args:
- threshold (float): all pixels equal to or above this value are inverted.
- p (float): probability of the image being solarized. Default value is 0.5
- """
- def __init__(self, threshold, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.threshold = threshold
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be solarized.
- Returns:
- PIL Image or Tensor: Randomly solarized image.
- """
- if torch.rand(1).item() < self.p:
- return F.solarize(img, self.threshold)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"
- class RandomAdjustSharpness(torch.nn.Module):
- """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
- it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- sharpness_factor (float): How much to adjust the sharpness. Can be
- any non-negative number. 0 gives a blurred image, 1 gives the
- original image while 2 increases the sharpness by a factor of 2.
- p (float): probability of the image being sharpened. Default value is 0.5
- """
- def __init__(self, sharpness_factor, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.sharpness_factor = sharpness_factor
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be sharpened.
- Returns:
- PIL Image or Tensor: Randomly sharpened image.
- """
- if torch.rand(1).item() < self.p:
- return F.adjust_sharpness(img, self.sharpness_factor)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"
- class RandomAutocontrast(torch.nn.Module):
- """Autocontrast the pixels of the given image randomly with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "L" or "RGB".
- Args:
- p (float): probability of the image being autocontrasted. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be autocontrasted.
- Returns:
- PIL Image or Tensor: Randomly autocontrasted image.
- """
- if torch.rand(1).item() < self.p:
- return F.autocontrast(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomEqualize(torch.nn.Module):
- """Equalize the histogram of the given image randomly with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
- Args:
- p (float): probability of the image being equalized. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be equalized.
- Returns:
- PIL Image or Tensor: Randomly equalized image.
- """
- if torch.rand(1).item() < self.p:
- return F.equalize(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class ElasticTransform(torch.nn.Module):
- """Transform a tensor image with elastic transformations.
- Given alpha and sigma, it will generate displacement
- vectors for all pixels based on random offsets. Alpha controls the strength
- and sigma controls the smoothness of the displacements.
- The displacements are added to an identity grid and the resulting grid is
- used to grid_sample from the image.
- Applications:
- Randomly transforms the morphology of objects in images and produces a
- see-through-water-like effect.
- Args:
- alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
- sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- fill (sequence or number): Pixel fill value for the area outside the transformed
- image. Default is ``0``. If given a number, the value is used for all bands.
- """
- def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
- super().__init__()
- _log_api_usage_once(self)
- if not isinstance(alpha, (float, Sequence)):
- raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}")
- if isinstance(alpha, Sequence) and len(alpha) != 2:
- raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}")
- if isinstance(alpha, Sequence):
- for element in alpha:
- if not isinstance(element, float):
- raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}")
- if isinstance(alpha, float):
- alpha = [float(alpha), float(alpha)]
- if isinstance(alpha, (list, tuple)) and len(alpha) == 1:
- alpha = [alpha[0], alpha[0]]
- self.alpha = alpha
- if not isinstance(sigma, (float, Sequence)):
- raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}")
- if isinstance(sigma, Sequence) and len(sigma) != 2:
- raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}")
- if isinstance(sigma, Sequence):
- for element in sigma:
- if not isinstance(element, float):
- raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}")
- if isinstance(sigma, float):
- sigma = [float(sigma), float(sigma)]
- if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
- sigma = [sigma[0], sigma[0]]
- self.sigma = sigma
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.interpolation = interpolation
- if isinstance(fill, (int, float)):
- fill = [float(fill)]
- elif isinstance(fill, (list, tuple)):
- fill = [float(f) for f in fill]
- else:
- raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
- self.fill = fill
- @staticmethod
- def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
- dx = torch.rand([1, 1] + size) * 2 - 1
- if sigma[0] > 0.0:
- kx = int(8 * sigma[0] + 1)
- # if kernel size is even we have to make it odd
- if kx % 2 == 0:
- kx += 1
- dx = F.gaussian_blur(dx, [kx, kx], sigma)
- dx = dx * alpha[0] / size[0]
- dy = torch.rand([1, 1] + size) * 2 - 1
- if sigma[1] > 0.0:
- ky = int(8 * sigma[1] + 1)
- # if kernel size is even we have to make it odd
- if ky % 2 == 0:
- ky += 1
- dy = F.gaussian_blur(dy, [ky, ky], sigma)
- dy = dy * alpha[1] / size[1]
- return torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
- def forward(self, tensor: Tensor) -> Tensor:
- """
- Args:
- tensor (PIL Image or Tensor): Image to be transformed.
- Returns:
- PIL Image or Tensor: Transformed image.
- """
- _, height, width = F.get_dimensions(tensor)
- displacement = self.get_params(self.alpha, self.sigma, [height, width])
- return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)
- def __repr__(self):
- format_string = self.__class__.__name__
- format_string += f"(alpha={self.alpha}"
- format_string += f", sigma={self.sigma}"
- format_string += f", interpolation={self.interpolation}"
- format_string += f", fill={self.fill})"
- return format_string
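- # Illustrative usage sketch (not part of the original module): applies a single elastic
- # displacement field; the alpha (strength) and sigma (smoothness) values and the input
- # shape below are assumptions for the example.
- def _example_elastic_transform():
-     img = torch.rand(3, 64, 64)
-     elastic = ElasticTransform(alpha=50.0, sigma=5.0)
-     return elastic(img)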