# test_video_gpu_decoder.py (3.7 KB)
  1. import math
  2. import os
  3. import pytest
  4. import torch
  5. import torchvision
  6. from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader
  7. try:
  8. import av
  9. except ImportError:
  10. av = None
  11. VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
  12. @pytest.mark.skipif(_HAS_GPU_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
  13. class TestVideoGPUDecoder:
  14. @pytest.mark.skipif(av is None, reason="PyAV unavailable")
  15. @pytest.mark.parametrize(
  16. "video_file",
  17. [
  18. "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
  19. "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
  20. "v_SoccerJuggling_g23_c01.avi",
  21. "v_SoccerJuggling_g24_c01.avi",
  22. "R6llTwEh07w.mp4",
  23. "SOX5yA1l24A.mp4",
  24. "WUzgd7C1pWA.mp4",
  25. ],
  26. )
  27. def test_frame_reading(self, video_file):
  28. torchvision.set_video_backend("cuda")
  29. full_path = os.path.join(VIDEO_DIR, video_file)
  30. decoder = VideoReader(full_path)
  31. with av.open(full_path) as container:
  32. for av_frame in container.decode(container.streams.video[0]):
  33. av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
  34. vision_frames = next(decoder)["data"]
  35. mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
  36. assert mean_delta < 0.75
  37. @pytest.mark.skipif(av is None, reason="PyAV unavailable")
  38. @pytest.mark.parametrize("keyframes", [True, False])
  39. @pytest.mark.parametrize(
  40. "full_path, duration",
  41. [
  42. (os.path.join(VIDEO_DIR, x), y)
  43. for x, y in [
  44. ("v_SoccerJuggling_g23_c01.avi", 8.0),
  45. ("v_SoccerJuggling_g24_c01.avi", 8.0),
  46. ("R6llTwEh07w.mp4", 10.0),
  47. ("SOX5yA1l24A.mp4", 11.0),
  48. ("WUzgd7C1pWA.mp4", 11.0),
  49. ]
  50. ],
  51. )
  52. def test_seek_reading(self, keyframes, full_path, duration):
  53. torchvision.set_video_backend("cuda")
  54. decoder = VideoReader(full_path)
  55. time = duration / 2
  56. decoder.seek(time, keyframes_only=keyframes)
  57. with av.open(full_path) as container:
  58. container.seek(int(time * 1000000), any_frame=not keyframes, backward=False)
  59. for av_frame in container.decode(container.streams.video[0]):
  60. av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
  61. vision_frames = next(decoder)["data"]
  62. mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
  63. assert mean_delta < 0.75
  64. @pytest.mark.skipif(av is None, reason="PyAV unavailable")
  65. @pytest.mark.parametrize(
  66. "video_file",
  67. [
  68. "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
  69. "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
  70. "v_SoccerJuggling_g23_c01.avi",
  71. "v_SoccerJuggling_g24_c01.avi",
  72. "R6llTwEh07w.mp4",
  73. "SOX5yA1l24A.mp4",
  74. "WUzgd7C1pWA.mp4",
  75. ],
  76. )
  77. def test_metadata(self, video_file):
  78. torchvision.set_video_backend("cuda")
  79. full_path = os.path.join(VIDEO_DIR, video_file)
  80. decoder = VideoReader(full_path)
  81. video_metadata = decoder.get_metadata()["video"]
  82. with av.open(full_path) as container:
  83. video = container.streams.video[0]
  84. av_duration = float(video.duration * video.time_base)
  85. assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2)
  86. assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2)
  87. if __name__ == "__main__":
  88. pytest.main([__file__])