import collections
import os
import urllib.error  # urllib.error.URLError is caught in test_fate_suite

import pytest
import torch
import torchvision
from pytest import approx
from torchvision.datasets.utils import download_url
from torchvision.io import _HAS_VIDEO_OPT, VideoReader

# WARNING: these tests have been skipped forever on the CI because the video ops
# are never properly available. This is bad, but things have been in a terrible
# state for a long time already as we write this comment, and we'll hopefully be
# able to get rid of this all soon.

try:
    import av

    # Do a version test too
    torchvision.io.video._check_av_available()
except ImportError:
    av = None

VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")

CheckerConfig = ["duration", "video_fps", "audio_sample_rate"]
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))


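# Backends to parametrize the tests over: the native "video_reader" backend is always
# included; "pyav" is added only when PyAV could be imported above.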
def backends():
    backends_ = ["video_reader"]
    if av is not None:
        backends_.append("pyav")
    return backends_


def fate(name, path="."):
    """Download and return a path to a sample from the FFmpeg test suite.

    See the `FFmpeg Automated Test Environment <https://www.ffmpeg.org/fate.html>`_
    """
    file_name = name.split("/")[1]
    download_url("http://fate.ffmpeg.org/fate-suite/" + name, path, file_name)
    return os.path.join(path, file_name)


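# Ground-truth duration, fps, and audio sample rate for the sample videos under assets/videos.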
test_videos = {
    "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None),
    "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
        duration=2.0, video_fps=30.0, audio_sample_rate=None
    ),
    "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None),
    "v_SoccerJuggling_g23_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None),
    "v_SoccerJuggling_g24_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None),
    "R6llTwEh07w.mp4": GroundTruth(duration=10.0, video_fps=30.0, audio_sample_rate=44100),
    "SOX5yA1l24A.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000),
    "WUzgd7C1pWA.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000),
}


@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
class TestVideoApi:
    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_frame_reading(self, test_video, backend):
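        """Decode the video stream (and, when present, the audio stream) with both PyAV
        and VideoReader, and check that the decoded frames and their pts match."""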
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)

        with av.open(full_path) as av_reader:
            if av_reader.streams.video:
                av_frames, vr_frames = [], []
                av_pts, vr_pts = [], []
                # get av frames
                for av_frame in av_reader.decode(av_reader.streams.video[0]):
                    av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
                    av_pts.append(av_frame.pts * av_frame.time_base)

                # get vr frames
                video_reader = VideoReader(full_path, "video")
                for vr_frame in video_reader:
                    vr_frames.append(vr_frame["data"])
                    vr_pts.append(vr_frame["pts"])

                # same number of frames
                assert len(vr_frames) == len(av_frames)
                assert len(vr_pts) == len(av_pts)

                # compare the frames and pts values
                for i in range(len(vr_frames)):
                    assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
                    mean_delta = torch.mean(torch.abs(av_frames[i].float() - vr_frames[i].float()))
                    # on average the difference is very small and caused by decoding
                    # TODO: assess empirically how to set this threshold; at the moment
                    # it's 1% of the 0-255 range (2.55), averaged over all frames
                    assert mean_delta.item() < 2.55

                del vr_frames, av_frames, vr_pts, av_pts

        # test audio reading compared to PyAV
        with av.open(full_path) as av_reader:
            if av_reader.streams.audio:
                av_frames, vr_frames = [], []
                av_pts, vr_pts = [], []
                # get av frames
                for av_frame in av_reader.decode(av_reader.streams.audio[0]):
                    av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
                    av_pts.append(av_frame.pts * av_frame.time_base)
                av_reader.close()

                # get vr frames
                video_reader = VideoReader(full_path, "audio")
                for vr_frame in video_reader:
                    vr_frames.append(vr_frame["data"])
                    vr_pts.append(vr_frame["pts"])

                # same number of frames
                assert len(vr_frames) == len(av_frames)
                assert len(vr_pts) == len(av_pts)

                # compare the frames and pts values
                for i in range(len(vr_frames)):
                    assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
                    max_delta = torch.max(torch.abs(av_frames[i].float() - vr_frames[i].float()))
                    # the decoded audio signals should never differ by more than 1e-3
                    assert max_delta.item() < 0.001

    @pytest.mark.parametrize("stream", ["video", "audio"])
    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
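        """Check that decoding a stream from an in-memory buffer yields the same
        frames and pts as decoding the same stream from a file path."""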
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)

        reader = VideoReader(full_path)
        reader_md = reader.get_metadata()
        if stream in reader_md:
            # Test video reading from file vs from memory
            vr_frames, vr_frames_mem = [], []
            vr_pts, vr_pts_mem = [], []

            # get vr frames from the file path
            video_reader = VideoReader(full_path, stream)
            for vr_frame in video_reader:
                vr_frames.append(vr_frame["data"])
                vr_pts.append(vr_frame["pts"])

            # get vr frames by reading from memory
            with open(full_path, "rb") as f:
                fbytes = f.read()
            video_reader_from_mem = VideoReader(fbytes, stream)
            for vr_frame_from_mem in video_reader_from_mem:
                vr_frames_mem.append(vr_frame_from_mem["data"])
                vr_pts_mem.append(vr_frame_from_mem["pts"])

            # same number of frames
            assert len(vr_frames) == len(vr_frames_mem)
            assert len(vr_pts) == len(vr_pts_mem)

            # compare the frames and pts values
            for i in range(len(vr_frames)):
                assert vr_pts[i] == vr_pts_mem[i]
                mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
                # on average the difference is very small and caused by decoding
                # TODO: assess empirically how to set this threshold; at the moment
                # it's 1% of the 0-255 range (2.55), averaged over all frames
                assert mean_delta.item() < 2.55

            del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
        else:
            del reader, reader_md

    @pytest.mark.parametrize("test_video,config", test_videos.items())
    @pytest.mark.parametrize("backend", backends())
    def test_metadata(self, test_video, config, backend):
        """
        Test that the metadata returned by the new video decoder API matches the
        ground-truth values for each sample video.
        """
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        reader = VideoReader(full_path, "video")
        reader_md = reader.get_metadata()
        assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
        assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)

    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", backends())
    def test_seek_start(self, test_video, backend):
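        """Check that seeking back to the start (and to a negative time) yields the
        same number of frames as the initial full decode."""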
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        video_reader = VideoReader(full_path, "video")
        num_frames = 0
        for _ in video_reader:
            num_frames += 1

        # now seek the container back to 0 and decode again: starting seeks are
        # often imprecise and may not land exactly at 0
        video_reader.seek(0)
        start_num_frames = 0
        for _ in video_reader:
            start_num_frames += 1

        assert start_num_frames == num_frames

        # now seek the container to < 0 to check for unexpected behaviour
        video_reader.seek(-1)
        start_num_frames = 0
        for _ in video_reader:
            start_num_frames += 1

        assert start_num_frames == num_frames

    @pytest.mark.parametrize("test_video", test_videos.keys())
    @pytest.mark.parametrize("backend", ["video_reader"])
    def test_accurateseek_middle(self, test_video, backend):
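        """Seek to the middle of the video and check that roughly half of the frames
        remain and that the first frame after the seek has a pts close to the target."""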
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)
        stream = "video"
        video_reader = VideoReader(full_path, stream)
        md = video_reader.get_metadata()
        duration = md[stream]["duration"][0]
        if duration is not None:
            num_frames = 0
            for _ in video_reader:
                num_frames += 1

            video_reader.seek(duration / 2)
            middle_num_frames = 0
            for _ in video_reader:
                middle_num_frames += 1

            assert middle_num_frames < num_frames
            assert middle_num_frames == approx(num_frames // 2, abs=1)

            video_reader.seek(duration / 2)
            frame = next(video_reader)
            lb = duration / 2 - 1 / md[stream]["fps"][0]
            ub = duration / 2 + 1 / md[stream]["fps"][0]
            assert (lb <= frame["pts"]) and (ub >= frame["pts"])

    def test_fate_suite(self):
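        """Download a subtitle sample from the FFmpeg test suite (FATE) and check that
        its subtitle stream metadata is reported."""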
        # TODO: remove the try-except statement once the connectivity issues are resolved
        try:
            video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
        except (urllib.error.URLError, ConnectionError) as error:
            pytest.skip(f"Skipping due to connectivity issues: {error}")
        vr = VideoReader(video_path)
        metadata = vr.get_metadata()

        assert metadata["subtitles"]["duration"] is not None
        os.remove(video_path)

    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    @pytest.mark.parametrize("backend", backends())
    def test_keyframe_reading(self, test_video, config, backend):
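        """Collect all keyframe pts with PyAV, then seek between consecutive keyframes
        in keyframes-only mode and check that VideoReader returns the same keyframes."""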
        torchvision.set_video_backend(backend)
        full_path = os.path.join(VIDEO_DIR, test_video)

        av_reader = av.open(full_path)
        # reduce the stream to keyframes only
        av_stream = av_reader.streams.video[0]
        av_stream.codec_context.skip_frame = "NONKEY"

        av_keyframes = []
        vr_keyframes = []
        if av_reader.streams.video:
            # get all keyframes using pyav; then seek into the video reader between
            # consecutive keyframes and assert that the returned pts are those keyframes
            for av_frame in av_reader.decode(av_stream):
                av_keyframes.append(float(av_frame.pts * av_frame.time_base))

        if len(av_keyframes) > 1:
            video_reader = VideoReader(full_path, "video")
            for i in range(1, len(av_keyframes)):
                seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
                data = next(video_reader.seek(seek_val, True))
                vr_keyframes.append(data["pts"])

            data = next(video_reader.seek(config.duration, True))
            vr_keyframes.append(data["pts"])

            assert len(av_keyframes) == len(vr_keyframes)

            # NOTE: this video gets a different keyframe with different
            # loaders (0.333 with pyav, 0.666 with video_reader)
            if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
                for i in range(len(av_keyframes)):
                    assert av_keyframes[i] == approx(vr_keyframes[i], rel=0.001)


if __name__ == "__main__":
    pytest.main([__file__])