transforms.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import random
  2. import numpy as np
  3. import torch
  4. from torchvision import transforms as T
  5. from torchvision.transforms import functional as F
  6. def pad_if_smaller(img, size, fill=0):
  7. min_size = min(img.size)
  8. if min_size < size:
  9. ow, oh = img.size
  10. padh = size - oh if oh < size else 0
  11. padw = size - ow if ow < size else 0
  12. img = F.pad(img, (0, 0, padw, padh), fill=fill)
  13. return img
  14. class Compose:
  15. def __init__(self, transforms):
  16. self.transforms = transforms
  17. def __call__(self, image, target):
  18. for t in self.transforms:
  19. image, target = t(image, target)
  20. return image, target
  21. class RandomResize:
  22. def __init__(self, min_size, max_size=None):
  23. self.min_size = min_size
  24. if max_size is None:
  25. max_size = min_size
  26. self.max_size = max_size
  27. def __call__(self, image, target):
  28. size = random.randint(self.min_size, self.max_size)
  29. image = F.resize(image, size, antialias=True)
  30. target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
  31. return image, target
  32. class RandomHorizontalFlip:
  33. def __init__(self, flip_prob):
  34. self.flip_prob = flip_prob
  35. def __call__(self, image, target):
  36. if random.random() < self.flip_prob:
  37. image = F.hflip(image)
  38. target = F.hflip(target)
  39. return image, target
  40. class RandomCrop:
  41. def __init__(self, size):
  42. self.size = size
  43. def __call__(self, image, target):
  44. image = pad_if_smaller(image, self.size)
  45. target = pad_if_smaller(target, self.size, fill=255)
  46. crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
  47. image = F.crop(image, *crop_params)
  48. target = F.crop(target, *crop_params)
  49. return image, target
  50. class CenterCrop:
  51. def __init__(self, size):
  52. self.size = size
  53. def __call__(self, image, target):
  54. image = F.center_crop(image, self.size)
  55. target = F.center_crop(target, self.size)
  56. return image, target
  57. class PILToTensor:
  58. def __call__(self, image, target):
  59. image = F.pil_to_tensor(image)
  60. target = torch.as_tensor(np.array(target), dtype=torch.int64)
  61. return image, target
  62. class ToDtype:
  63. def __init__(self, dtype, scale=False):
  64. self.dtype = dtype
  65. self.scale = scale
  66. def __call__(self, image, target):
  67. if not self.scale:
  68. return image.to(dtype=self.dtype), target
  69. image = F.convert_image_dtype(image, self.dtype)
  70. return image, target
  71. class Normalize:
  72. def __init__(self, mean, std):
  73. self.mean = mean
  74. self.std = std
  75. def __call__(self, image, target):
  76. image = F.normalize(image, mean=self.mean, std=self.std)
  77. return image, target