test_transforms_video.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import random
  2. import warnings
  3. import numpy as np
  4. import pytest
  5. import torch
  6. from common_utils import assert_equal
  7. from torchvision.transforms import Compose
  8. try:
  9. from scipy import stats
  10. except ImportError:
  11. stats = None
  12. with warnings.catch_warnings(record=True):
  13. warnings.simplefilter("always")
  14. import torchvision.transforms._transforms_video as transforms
  15. class TestVideoTransforms:
  16. def test_random_crop_video(self):
  17. numFrames = random.randint(4, 128)
  18. height = random.randint(10, 32) * 2
  19. width = random.randint(10, 32) * 2
  20. oheight = random.randint(5, (height - 2) / 2) * 2
  21. owidth = random.randint(5, (width - 2) / 2) * 2
  22. clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
  23. result = Compose(
  24. [
  25. transforms.ToTensorVideo(),
  26. transforms.RandomCropVideo((oheight, owidth)),
  27. ]
  28. )(clip)
  29. assert result.size(2) == oheight
  30. assert result.size(3) == owidth
  31. transforms.RandomCropVideo((oheight, owidth)).__repr__()
  32. def test_random_resized_crop_video(self):
  33. numFrames = random.randint(4, 128)
  34. height = random.randint(10, 32) * 2
  35. width = random.randint(10, 32) * 2
  36. oheight = random.randint(5, (height - 2) / 2) * 2
  37. owidth = random.randint(5, (width - 2) / 2) * 2
  38. clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
  39. result = Compose(
  40. [
  41. transforms.ToTensorVideo(),
  42. transforms.RandomResizedCropVideo((oheight, owidth)),
  43. ]
  44. )(clip)
  45. assert result.size(2) == oheight
  46. assert result.size(3) == owidth
  47. transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()
  48. def test_center_crop_video(self):
  49. numFrames = random.randint(4, 128)
  50. height = random.randint(10, 32) * 2
  51. width = random.randint(10, 32) * 2
  52. oheight = random.randint(5, (height - 2) / 2) * 2
  53. owidth = random.randint(5, (width - 2) / 2) * 2
  54. clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
  55. oh1 = (height - oheight) // 2
  56. ow1 = (width - owidth) // 2
  57. clipNarrow = clip[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth, :]
  58. clipNarrow.fill_(0)
  59. result = Compose(
  60. [
  61. transforms.ToTensorVideo(),
  62. transforms.CenterCropVideo((oheight, owidth)),
  63. ]
  64. )(clip)
  65. msg = (
  66. "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
  67. )
  68. assert result.sum().item() == 0, msg
  69. oheight += 1
  70. owidth += 1
  71. result = Compose(
  72. [
  73. transforms.ToTensorVideo(),
  74. transforms.CenterCropVideo((oheight, owidth)),
  75. ]
  76. )(clip)
  77. sum1 = result.sum()
  78. msg = (
  79. "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
  80. )
  81. assert sum1.item() > 1, msg
  82. oheight += 1
  83. owidth += 1
  84. result = Compose(
  85. [
  86. transforms.ToTensorVideo(),
  87. transforms.CenterCropVideo((oheight, owidth)),
  88. ]
  89. )(clip)
  90. sum2 = result.sum()
  91. msg = (
  92. "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
  93. )
  94. assert sum2.item() > 1, msg
  95. assert sum2.item() > sum1.item(), msg
  96. @pytest.mark.skipif(stats is None, reason="scipy.stats is not available")
  97. @pytest.mark.parametrize("channels", [1, 3])
  98. def test_normalize_video(self, channels):
  99. def samples_from_standard_normal(tensor):
  100. p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue
  101. return p_value > 0.0001
  102. random_state = random.getstate()
  103. random.seed(42)
  104. numFrames = random.randint(4, 128)
  105. height = random.randint(32, 256)
  106. width = random.randint(32, 256)
  107. mean = random.random()
  108. std = random.random()
  109. clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
  110. mean = [clip[c].mean().item() for c in range(channels)]
  111. std = [clip[c].std().item() for c in range(channels)]
  112. normalized = transforms.NormalizeVideo(mean, std)(clip)
  113. assert samples_from_standard_normal(normalized)
  114. random.setstate(random_state)
  115. # Checking the optional in-place behaviour
  116. tensor = torch.rand((3, 128, 16, 16))
  117. tensor_inplace = transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(tensor)
  118. assert_equal(tensor, tensor_inplace)
  119. transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True).__repr__()
  120. def test_to_tensor_video(self):
  121. numFrames, height, width = 64, 4, 4
  122. trans = transforms.ToTensorVideo()
  123. with pytest.raises(TypeError):
  124. np_rng = np.random.RandomState(0)
  125. trans(np_rng.rand(numFrames, height, width, 1).tolist())
  126. with pytest.raises(TypeError):
  127. trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))
  128. with pytest.raises(ValueError):
  129. trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
  130. with pytest.raises(ValueError):
  131. trans(torch.ones((height, width, 3), dtype=torch.uint8))
  132. with pytest.raises(ValueError):
  133. trans(torch.ones((width, 3), dtype=torch.uint8))
  134. with pytest.raises(ValueError):
  135. trans(torch.ones((3), dtype=torch.uint8))
  136. trans.__repr__()
  137. @pytest.mark.parametrize("p", (0, 1))
  138. def test_random_horizontal_flip_video(self, p):
  139. clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
  140. hclip = clip.flip(-1)
  141. out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
  142. if p == 0:
  143. torch.testing.assert_close(out, clip)
  144. elif p == 1:
  145. torch.testing.assert_close(out, hclip)
  146. transforms.RandomHorizontalFlipVideo().__repr__()
  147. if __name__ == "__main__":
  148. pytest.main([__file__])