# _functional_video.py
import warnings

import torch

# Emitted once at import time: this private module is kept only as a deprecated
# backward-compatibility alias for torchvision.transforms.functional.
warnings.warn(
    "The 'torchvision.transforms._functional_video' module is deprecated since 0.12 and will be removed in the future. "
    "Please use the 'torchvision.transforms.functional' module instead."
)
  7. def _is_tensor_video_clip(clip):
  8. if not torch.is_tensor(clip):
  9. raise TypeError("clip should be Tensor. Got %s" % type(clip))
  10. if not clip.ndimension() == 4:
  11. raise ValueError("clip should be 4D. Got %dD" % clip.dim())
  12. return True
  13. def crop(clip, i, j, h, w):
  14. """
  15. Args:
  16. clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
  17. """
  18. if len(clip.size()) != 4:
  19. raise ValueError("clip should be a 4D tensor")
  20. return clip[..., i : i + h, j : j + w]
  21. def resize(clip, target_size, interpolation_mode):
  22. if len(target_size) != 2:
  23. raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
  24. return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
  25. def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
  26. """
  27. Do spatial cropping and resizing to the video clip
  28. Args:
  29. clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
  30. i (int): i in (i,j) i.e coordinates of the upper left corner.
  31. j (int): j in (i,j) i.e coordinates of the upper left corner.
  32. h (int): Height of the cropped region.
  33. w (int): Width of the cropped region.
  34. size (tuple(int, int)): height and width of resized clip
  35. Returns:
  36. clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
  37. """
  38. if not _is_tensor_video_clip(clip):
  39. raise ValueError("clip should be a 4D torch.tensor")
  40. clip = crop(clip, i, j, h, w)
  41. clip = resize(clip, size, interpolation_mode)
  42. return clip
  43. def center_crop(clip, crop_size):
  44. if not _is_tensor_video_clip(clip):
  45. raise ValueError("clip should be a 4D torch.tensor")
  46. h, w = clip.size(-2), clip.size(-1)
  47. th, tw = crop_size
  48. if h < th or w < tw:
  49. raise ValueError("height and width must be no smaller than crop_size")
  50. i = int(round((h - th) / 2.0))
  51. j = int(round((w - tw) / 2.0))
  52. return crop(clip, i, j, th, tw)
  53. def to_tensor(clip):
  54. """
  55. Convert tensor data type from uint8 to float, divide value by 255.0 and
  56. permute the dimensions of clip tensor
  57. Args:
  58. clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
  59. Return:
  60. clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
  61. """
  62. _is_tensor_video_clip(clip)
  63. if not clip.dtype == torch.uint8:
  64. raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
  65. return clip.float().permute(3, 0, 1, 2) / 255.0
  66. def normalize(clip, mean, std, inplace=False):
  67. """
  68. Args:
  69. clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
  70. mean (tuple): pixel RGB mean. Size is (3)
  71. std (tuple): pixel standard deviation. Size is (3)
  72. Returns:
  73. normalized clip (torch.tensor): Size is (C, T, H, W)
  74. """
  75. if not _is_tensor_video_clip(clip):
  76. raise ValueError("clip should be a 4D torch.tensor")
  77. if not inplace:
  78. clip = clip.clone()
  79. mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
  80. std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
  81. clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
  82. return clip
  83. def hflip(clip):
  84. """
  85. Args:
  86. clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
  87. Returns:
  88. flipped clip (torch.tensor): Size is (C, T, H, W)
  89. """
  90. if not _is_tensor_video_clip(clip):
  91. raise ValueError("clip should be a 4D torch.tensor")
  92. return clip.flip(-1)