123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- import random
- import numpy as np
- import torch
- from torchvision import transforms as T
- from torchvision.transforms import functional as F
def pad_if_smaller(img, size, fill=0):
    """Grow ``img`` on its right/bottom edges until no side is shorter than ``size``.

    Args:
        img: image exposing ``size`` as ``(width, height)`` (presumably a PIL
            image — confirm against callers).
        size: minimum side length that both dimensions must reach.
        fill: value written into the newly added padding area.

    Returns:
        ``img`` unchanged when its smaller side already reaches ``size``,
        otherwise the padded image returned by ``F.pad``.
    """
    if not (min(img.size) < size):
        return img
    width, height = img.size
    extra_right = size - width if width < size else 0
    extra_bottom = size - height if height < size else 0
    # Padding tuple is (left, top, right, bottom): only the right and bottom grow,
    # so existing pixel coordinates are preserved.
    return F.pad(img, (0, 0, extra_right, extra_bottom), fill=fill)
class Compose:
    """Run a sequence of joint ``(image, target)`` transforms in order.

    Each transform is called as ``step(image, target)`` and must return a
    two-item result that is unpacked into the next ``(image, target)`` pair.
    """

    def __init__(self, transforms):
        # Iterated front-to-back on every call.
        self.transforms = transforms

    def __call__(self, image, target):
        current_image, current_target = image, target
        for step in self.transforms:
            outcome = step(current_image, current_target)
            current_image, current_target = outcome
        return current_image, current_target
class RandomResize:
    """Resize image and target to one integer size drawn from ``[min_size, max_size]``.

    The size is drawn with ``random.randint``, so both bounds are inclusive.
    The same drawn value is passed to ``F.resize`` for image and target.
    (For an int size, torchvision documents that the shorter edge is matched —
    confirm against the installed version.)
    """

    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        # Without an explicit upper bound the range collapses to min_size.
        self.max_size = min_size if max_size is None else max_size

    def __call__(self, image, target):
        side = random.randint(self.min_size, self.max_size)
        resized_image = F.resize(image, side, antialias=True)
        # Nearest-neighbour avoids blending target values (presumably label ids).
        resized_target = F.resize(
            target, side, interpolation=T.InterpolationMode.NEAREST
        )
        return resized_image, resized_target
class RandomHorizontalFlip:
    """Flip image and target horizontally together with probability ``flip_prob``."""

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        # A single random draw decides for both, keeping image and target aligned.
        if random.random() < self.flip_prob:
            return F.hflip(image), F.hflip(target)
        return image, target
class RandomCrop:
    """Cut the same random ``size`` x ``size`` window out of image and target."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        # Pad first so the window always fits. The target is padded with 255
        # (presumably the ignore label — confirm against the loss configuration).
        padded_image = pad_if_smaller(image, self.size)
        padded_target = pad_if_smaller(target, self.size, fill=255)
        window = T.RandomCrop.get_params(padded_image, (self.size, self.size))
        return F.crop(padded_image, *window), F.crop(padded_target, *window)
class CenterCrop:
    """Crop the central ``size`` region out of both image and target via ``F.center_crop``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        crop_size = self.size
        return F.center_crop(image, crop_size), F.center_crop(target, crop_size)
class PILToTensor:
    """Turn the image into a tensor with ``F.pil_to_tensor`` and the target into an int64 tensor."""

    def __call__(self, image, target):
        image_tensor = F.pil_to_tensor(image)
        # Go through NumPy so any array-like target (presumably a label mask)
        # ends up as a torch.int64 tensor.
        target_array = np.array(target)
        target_tensor = torch.as_tensor(target_array, dtype=torch.int64)
        return image_tensor, target_tensor
class ToDtype:
    """Convert the image to ``dtype`` and pass the target through untouched.

    If ``scale`` is truthy, the image goes through ``F.convert_image_dtype``.
    Otherwise it receives a plain ``Tensor.to(dtype=...)`` cast.
    """

    def __init__(self, dtype, scale=False):
        self.dtype = dtype
        self.scale = scale

    def __call__(self, image, target):
        if self.scale:
            return F.convert_image_dtype(image, self.dtype), target
        return image.to(dtype=self.dtype), target
class Normalize:
    """Apply ``F.normalize`` with the stored ``mean``/``std`` to the image; the target is returned as-is."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        normalized_image = F.normalize(image, mean=self.mean, std=self.std)
        return normalized_image, target
|