#!/usr/bin/env python3
import numbers
import random
import warnings

from torchvision.transforms import RandomCrop, RandomResizedCrop

from . import _functional_video as F

__all__ = [
    "RandomCropVideo",
    "RandomResizedCropVideo",
    "CenterCropVideo",
    "NormalizeVideo",
    "ToTensorVideo",
    "RandomHorizontalFlipVideo",
]

warnings.warn(
    "The 'torchvision.transforms._transforms_video' module is deprecated since 0.12 and will be removed in the future. "
    "Please use the 'torchvision.transforms' module instead."
)
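
# Note: the transforms below operate on video clips as torch tensors in
# (C, T, H, W) layout; ToTensorVideo is the entry point that converts a
# uint8 clip stored as (T, H, W, C) into that layout (see its docstring).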


class RandomCropVideo(RandomCrop):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (C, T, OH, OW)
        """
        i, j, h, w = self.get_params(clip, self.size)
        return F.crop(clip, i, j, h, w)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class RandomResizedCropVideo(RandomResizedCrop):
    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode
        self.scale = scale
        self.ratio = ratio

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, H, W)
        """
        i, j, h, w = self.get_params(clip, self.scale, self.ratio)
        return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, "
            f"scale={self.scale}, ratio={self.ratio})"
        )


class CenterCropVideo:
    def __init__(self, crop_size):
        if isinstance(crop_size, numbers.Number):
            self.crop_size = (int(crop_size), int(crop_size))
        else:
            self.crop_size = crop_size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: central cropping of video clip. Size is
                (C, T, crop_size, crop_size)
        """
        return F.center_crop(clip, self.crop_size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(crop_size={self.crop_size})"


class NormalizeVideo:
    """
    Normalize the video clip by mean subtraction and division by standard deviation.

    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether to do the normalization in place
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        """
        return F.normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"


class ToTensorVideo:
    """
    Convert tensor data type from uint8 to float, divide values by 255.0 and
    permute the dimensions of the clip tensor.
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
        Returns:
            clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
        """
        return F.to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__


class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability.

    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (C, T, H, W)
        Returns:
            clip (torch.tensor): Size is (C, T, H, W)
        """
        if random.random() < self.p:
            clip = F.hflip(clip)
        return clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
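

# Illustrative usage sketch (not part of the original module): a minimal
# example of chaining the transforms above on a synthetic clip. The clip
# shape and the mean/std values below are arbitrary placeholders, not
# recommended statistics. Runs only when the module is executed within the
# package (e.g. python -m torchvision.transforms._transforms_video).
if __name__ == "__main__":
    import torch

    # Fake uint8 clip in (T, H, W, C) layout, as expected by ToTensorVideo.
    clip = torch.randint(0, 256, (16, 128, 171, 3), dtype=torch.uint8)

    pipeline = [
        ToTensorVideo(),                   # uint8 (T, H, W, C) -> float (C, T, H, W) in [0, 1]
        RandomResizedCropVideo(112),       # random scale/aspect crop, resized to 112x112
        RandomHorizontalFlipVideo(p=0.5),  # left/right flip with probability 0.5
        NormalizeVideo(mean=(0.45, 0.45, 0.45), std=(0.225, 0.225, 0.225)),
    ]
    for t in pipeline:
        clip = t(clip)

    print(clip.shape)  # expected: torch.Size([3, 16, 112, 112])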