import collections import math import os from fractions import Fraction import numpy as np import pytest import torch import torchvision.io as io from common_utils import assert_equal from numpy.random import randint from pytest import approx from torchvision import set_video_backend from torchvision.io import _HAS_VIDEO_OPT try: import av # Do a version test too io.video._check_av_available() except ImportError: av = None VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") CheckerConfig = [ "duration", "video_fps", "audio_sample_rate", # We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are # slightly different between TorchVision decoder and PyAv decoder. So omit it during check "check_aframes", "check_aframe_pts", ] GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig)) all_check_config = GroundTruth( duration=0, video_fps=0, audio_sample_rate=0, check_aframes=True, check_aframe_pts=True, ) test_videos = { "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth( duration=2.0, video_fps=30.0, audio_sample_rate=None, check_aframes=True, check_aframe_pts=True, ), "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth( duration=2.0, video_fps=30.0, audio_sample_rate=None, check_aframes=True, check_aframe_pts=True, ), "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth( duration=2.0, video_fps=30.0, audio_sample_rate=None, check_aframes=True, check_aframe_pts=True, ), "v_SoccerJuggling_g23_c01.avi": GroundTruth( duration=8.0, video_fps=29.97, audio_sample_rate=None, check_aframes=True, check_aframe_pts=True, ), "v_SoccerJuggling_g24_c01.avi": GroundTruth( duration=8.0, video_fps=29.97, audio_sample_rate=None, check_aframes=True, check_aframe_pts=True, ), "R6llTwEh07w.mp4": GroundTruth( duration=10.0, video_fps=30.0, audio_sample_rate=44100, # PyAv miss one audio frame at the beginning (pts=0) check_aframes=False, check_aframe_pts=False, ), "SOX5yA1l24A.mp4": GroundTruth( duration=11.0, video_fps=29.97, audio_sample_rate=48000, # PyAv miss one audio frame at the beginning (pts=0) check_aframes=False, check_aframe_pts=False, ), "WUzgd7C1pWA.mp4": GroundTruth( duration=11.0, video_fps=29.97, audio_sample_rate=48000, # PyAv miss one audio frame at the beginning (pts=0) check_aframes=False, check_aframe_pts=False, ), } DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase") # av_seek_frame is imprecise so seek to a timestamp earlier by a margin # The unit of margin is second SEEK_FRAME_MARGIN = 0.25 def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4): """ Args: container: pyav container start_pts/end_pts: the starting/ending Presentation TimeStamp where frames are read stream: pyav stream stream_name: a dictionary of streams. For example, {"video": 0} means video stream at stream index 0 buffer_size: pts of frames decoded by PyAv is not guaranteed to be in ascending order. We need to decode more frames even when we meet end pts """ # seeking in the stream is imprecise. Thus, seek to an earlier PTS by a margin margin = 1 seek_offset = max(start_pts - margin, 0) container.seek(seek_offset, any_frame=False, backward=True, stream=stream) frames = {} buffer_count = 0 for frame in container.decode(**stream_name): if frame.pts < start_pts: continue if frame.pts <= end_pts: frames[frame.pts] = frame else: buffer_count += 1 if buffer_count >= buffer_size: break result = [frames[pts] for pts in sorted(frames)] return result def _get_timebase_by_av_module(full_path): container = av.open(full_path) video_time_base = container.streams.video[0].time_base if container.streams.audio: audio_time_base = container.streams.audio[0].time_base else: audio_time_base = None return video_time_base, audio_time_base def _fraction_to_tensor(fraction): ret = torch.zeros([2], dtype=torch.int32) ret[0] = fraction.numerator ret[1] = fraction.denominator return ret def _decode_frames_by_av_module( full_path, video_start_pts=0, video_end_pts=None, audio_start_pts=0, audio_end_pts=None, ): """ Use PyAv to decode video frames. This provides a reference for our decoder to compare the decoding results. Input arguments: full_path: video file path video_start_pts/video_end_pts: the starting/ending Presentation TimeStamp where frames are read """ if video_end_pts is None: video_end_pts = float("inf") if audio_end_pts is None: audio_end_pts = float("inf") container = av.open(full_path) video_frames = [] vtimebase = torch.zeros([0], dtype=torch.int32) if container.streams.video: video_frames = _read_from_stream( container, video_start_pts, video_end_pts, container.streams.video[0], {"video": 0}, ) # container.streams.video[0].average_rate is not a reliable estimator of # frame rate. It can be wrong for certain codec, such as VP80 # So we do not return video fps here vtimebase = _fraction_to_tensor(container.streams.video[0].time_base) audio_frames = [] atimebase = torch.zeros([0], dtype=torch.int32) if container.streams.audio: audio_frames = _read_from_stream( container, audio_start_pts, audio_end_pts, container.streams.audio[0], {"audio": 0}, ) atimebase = _fraction_to_tensor(container.streams.audio[0].time_base) container.close() vframes = [frame.to_rgb().to_ndarray() for frame in video_frames] vframes = torch.as_tensor(np.stack(vframes)) vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64) aframes = [frame.to_ndarray() for frame in audio_frames] if aframes: aframes = np.transpose(np.concatenate(aframes, axis=1)) aframes = torch.as_tensor(aframes) else: aframes = torch.empty((1, 0), dtype=torch.float32) aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64) return DecoderResult( vframes=vframes, vframe_pts=vframe_pts, vtimebase=vtimebase, aframes=aframes, aframe_pts=aframe_pts, atimebase=atimebase, ) def _pts_convert(pts, timebase_from, timebase_to, round_func=math.floor): """convert pts between different time bases Args: pts: presentation timestamp, float timebase_from: original timebase. Fraction timebase_to: new timebase. Fraction round_func: rounding function. """ new_pts = Fraction(pts, 1) * timebase_from / timebase_to return int(round_func(new_pts)) def _get_video_tensor(video_dir, video_file): """open a video file, and represent the video data by a PT tensor""" full_path = os.path.join(video_dir, video_file) assert os.path.exists(full_path), "File not found: %s" % full_path with open(full_path, "rb") as fp: video_tensor = torch.frombuffer(fp.read(), dtype=torch.uint8) return full_path, video_tensor @pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg") class TestVideoReader: def check_separate_decoding_result(self, tv_result, config): """check the decoding results from TorchVision decoder""" ( vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration, ) = tv_result video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item()) assert video_duration == approx(config.duration, abs=0.5) assert vfps.item() == approx(config.video_fps, abs=0.5) if asample_rate.numel() > 0: assert asample_rate.item() == config.audio_sample_rate audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item()) assert audio_duration == approx(config.duration, abs=0.5) # check if pts of video frames are sorted in ascending order for i in range(len(vframe_pts) - 1): assert vframe_pts[i] < vframe_pts[i + 1] if len(aframe_pts) > 1: # check if pts of audio frames are sorted in ascending order for i in range(len(aframe_pts) - 1): assert aframe_pts[i] < aframe_pts[i + 1] def check_probe_result(self, result, config): vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item()) assert video_duration == approx(config.duration, abs=0.5) assert vfps.item() == approx(config.video_fps, abs=0.5) if asample_rate.numel() > 0: assert asample_rate.item() == config.audio_sample_rate audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item()) assert audio_duration == approx(config.duration, abs=0.5) def check_meta_result(self, result, config): assert result.video_duration == approx(config.duration, abs=0.5) assert result.video_fps == approx(config.video_fps, abs=0.5) if result.has_audio > 0: assert result.audio_sample_rate == config.audio_sample_rate assert result.audio_duration == approx(config.duration, abs=0.5) def compare_decoding_result(self, tv_result, ref_result, config=all_check_config): """ Compare decoding results from two sources. Args: tv_result: decoding results from TorchVision decoder ref_result: reference decoding results which can be from either PyAv decoder or TorchVision decoder with getPtsOnly = 1 config: config of decoding results checker """ ( vframes, vframe_pts, vtimebase, _vfps, _vduration, aframes, aframe_pts, atimebase, _asample_rate, _aduration, ) = tv_result if isinstance(ref_result, list): # the ref_result is from new video_reader decoder ref_result = DecoderResult( vframes=ref_result[0], vframe_pts=ref_result[1], vtimebase=ref_result[2], aframes=ref_result[5], aframe_pts=ref_result[6], atimebase=ref_result[7], ) if vframes.numel() > 0 and ref_result.vframes.numel() > 0: mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float())) assert mean_delta == approx(0.0, abs=8.0) mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())) assert mean_delta == approx(0.0, abs=1.0) assert_equal(vtimebase, ref_result.vtimebase) if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0: """Audio stream is available and audio frame is required to return from decoder""" assert_equal(aframes, ref_result.aframes) if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0: """Audio stream is available""" assert_equal(aframe_pts, ref_result.aframe_pts) assert_equal(atimebase, ref_result.atimebase) @pytest.mark.parametrize("test_video", test_videos.keys()) def test_stress_test_read_video_from_file(self, test_video): pytest.skip( "This stress test will iteratively decode the same set of videos." "It helps to detect memory leak but it takes lots of time to run." "By default, it is disabled" ) num_iter = 10000 # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 for _i in range(num_iter): full_path = os.path.join(VIDEO_DIR, test_video) # pass 1: decode all frames using new decoder torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) @pytest.mark.parametrize("test_video,config", test_videos.items()) def test_read_video_from_file(self, test_video, config): """ Test the case when decoder starts with a video file to decode frames. """ # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) # pass 1: decode all frames using new decoder tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) # pass 2: decode all frames using av pyav_result = _decode_frames_by_av_module(full_path) # check results from TorchVision decoder self.check_separate_decoding_result(tv_result, config) # compare decoding results self.compare_decoding_result(tv_result, pyav_result, config) @pytest.mark.parametrize("test_video,config", test_videos.items()) @pytest.mark.parametrize("read_video_stream,read_audio_stream", [(1, 0), (0, 1)]) def test_read_video_from_file_read_single_stream_only( self, test_video, config, read_video_stream, read_audio_stream ): """ Test the case when decoder starts with a video file to decode frames, and only reads video stream and ignores audio stream """ # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) # decode all frames using new decoder tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly read_video_stream, width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, read_audio_stream, samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) ( vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration, ) = tv_result assert (vframes.numel() > 0) is bool(read_video_stream) assert (vframe_pts.numel() > 0) is bool(read_video_stream) assert (vtimebase.numel() > 0) is bool(read_video_stream) assert (vfps.numel() > 0) is bool(read_video_stream) expect_audio_data = read_audio_stream == 1 and config.audio_sample_rate is not None assert (aframes.numel() > 0) is bool(expect_audio_data) assert (aframe_pts.numel() > 0) is bool(expect_audio_data) assert (atimebase.numel() > 0) is bool(expect_audio_data) assert (asample_rate.numel() > 0) is bool(expect_audio_data) @pytest.mark.parametrize("test_video", test_videos.keys()) def test_read_video_from_file_rescale_min_dimension(self, test_video): """ Test the case when decoder starts with a video file to decode frames, and video min dimension between height and width is set. """ # video related width, height, min_dimension, max_dimension = 0, 0, 128, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2)) @pytest.mark.parametrize("test_video", test_videos.keys()) def test_read_video_from_file_rescale_max_dimension(self, test_video): """ Test the case when decoder starts with a video file to decode frames, and video min dimension between height and width is set. """ # video related width, height, min_dimension, max_dimension = 0, 0, 0, 85 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2)) @pytest.mark.parametrize("test_video", test_videos.keys()) def test_read_video_from_file_rescale_both_min_max_dimension(self, test_video): """ Test the case when decoder starts with a video file to decode frames, and video min dimension between height and width is set. """ # video related width, height, min_dimension, max_dimension = 0, 0, 64, 85 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2)) assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2)) @pytest.mark.parametrize("test_video", test_videos.keys()) def test_read_video_from_file_rescale_width(self, test_video): """ Test the case when decoder starts with a video file to decode frames, and video width is set. """ # video related width, height, min_dimension, max_dimension = 256, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) assert tv_result[0].size(2) == width @pytest.mark.parametrize("test_video", test_videos.keys()) def test_read_video_from_file_rescale_height(self, test_video): """ Test the case when decoder starts with a video file to decode frames, and video height is set. """ # video related width, height, min_dimension, max_dimension = 0, 224, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) assert tv_result[0].size(1) == height @pytest.mark.parametrize("test_video", test_videos.keys()) def test_read_video_from_file_rescale_width_and_height(self, test_video): """ Test the case when decoder starts with a video file to decode frames, and both video height and width are set. """ # video related width, height, min_dimension, max_dimension = 320, 240, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) assert tv_result[0].size(1) == height assert tv_result[0].size(2) == width @pytest.mark.parametrize("test_video", test_videos.keys()) @pytest.mark.parametrize("samples", [9600, 96000]) def test_read_video_from_file_audio_resampling(self, test_video, samples): """ Test the case when decoder starts with a video file to decode frames, and audio waveform are resampled """ # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related channels = 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path = os.path.join(VIDEO_DIR, test_video) tv_result = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) ( vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration, ) = tv_result if aframes.numel() > 0: assert samples == asample_rate.item() assert 1 == aframes.size(1) # when audio stream is found duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1]) assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item()) @pytest.mark.parametrize("test_video,config", test_videos.items()) def test_compare_read_video_from_memory_and_file(self, test_video, config): """ Test the case when video is already in memory, and decoder reads data in memory """ # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) # pass 1: decode all frames using cpp decoder tv_result_memory = torch.ops.video_reader.read_video_from_memory( video_tensor, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) self.check_separate_decoding_result(tv_result_memory, config) # pass 2: decode all frames from file tv_result_file = torch.ops.video_reader.read_video_from_file( full_path, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) self.check_separate_decoding_result(tv_result_file, config) # finally, compare results decoded from memory and file self.compare_decoding_result(tv_result_memory, tv_result_file) @pytest.mark.parametrize("test_video,config", test_videos.items()) def test_read_video_from_memory(self, test_video, config): """ Test the case when video is already in memory, and decoder reads data in memory """ # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) # pass 1: decode all frames using cpp decoder tv_result = torch.ops.video_reader.read_video_from_memory( video_tensor, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) # pass 2: decode all frames using av pyav_result = _decode_frames_by_av_module(full_path) self.check_separate_decoding_result(tv_result, config) self.compare_decoding_result(tv_result, pyav_result, config) @pytest.mark.parametrize("test_video,config", test_videos.items()) def test_read_video_from_memory_get_pts_only(self, test_video, config): """ Test the case when video is already in memory, and decoder reads data in memory. Compare frame pts between decoding for pts only and full decoding for both pts and frame data """ # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) # pass 1: decode all frames using cpp decoder tv_result = torch.ops.video_reader.read_video_from_memory( video_tensor, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) assert abs(config.video_fps - tv_result[3].item()) < 0.01 # pass 2: decode all frames to get PTS only using cpp decoder tv_result_pts_only = torch.ops.video_reader.read_video_from_memory( video_tensor, SEEK_FRAME_MARGIN, 1, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) assert not tv_result_pts_only[0].numel() assert not tv_result_pts_only[5].numel() self.compare_decoding_result(tv_result, tv_result_pts_only) @pytest.mark.parametrize("test_video,config", test_videos.items()) @pytest.mark.parametrize("num_frames", [4, 8, 16, 32, 64, 128]) def test_read_video_in_range_from_memory(self, test_video, config, num_frames): """ Test the case when video is already in memory, and decoder reads data in memory. In addition, decoder takes meaningful start- and end PTS as input, and decode frames within that interval """ full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 # pass 1: decode all frames using new decoder tv_result = torch.ops.video_reader.read_video_from_memory( video_tensor, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) ( vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration, ) = tv_result assert abs(config.video_fps - vfps.item()) < 0.01 start_pts_ind_max = vframe_pts.size(0) - num_frames if start_pts_ind_max <= 0: return # randomly pick start pts start_pts_ind = randint(0, start_pts_ind_max) end_pts_ind = start_pts_ind + num_frames - 1 video_start_pts = vframe_pts[start_pts_ind] video_end_pts = vframe_pts[end_pts_ind] video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1] if len(atimebase) > 0: # when audio stream is available audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1] audio_start_pts = _pts_convert( video_start_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(audio_timebase_num.item(), audio_timebase_den.item()), math.floor, ) audio_end_pts = _pts_convert( video_end_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(audio_timebase_num.item(), audio_timebase_den.item()), math.ceil, ) # pass 2: decode frames in the randomly generated range tv_result = torch.ops.video_reader.read_video_from_memory( video_tensor, SEEK_FRAME_MARGIN, 0, # getPtsOnly 1, # readVideoStream width, height, min_dimension, max_dimension, video_start_pts, video_end_pts, video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, audio_start_pts, audio_end_pts, audio_timebase_num, audio_timebase_den, ) # pass 3: decode frames in range using PyAv video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path) video_start_pts_av = _pts_convert( video_start_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_av.numerator, video_timebase_av.denominator), math.floor, ) video_end_pts_av = _pts_convert( video_end_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_av.numerator, video_timebase_av.denominator), math.ceil, ) if audio_timebase_av: audio_start_pts = _pts_convert( video_start_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator), math.floor, ) audio_end_pts = _pts_convert( video_end_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator), math.ceil, ) pyav_result = _decode_frames_by_av_module( full_path, video_start_pts_av, video_end_pts_av, audio_start_pts, audio_end_pts, ) assert tv_result[0].size(0) == num_frames if pyav_result.vframes.size(0) == num_frames: # if PyAv decodes a different number of video frames, skip # comparing the decoding results between Torchvision video reader # and PyAv self.compare_decoding_result(tv_result, pyav_result, config) @pytest.mark.parametrize("test_video,config", test_videos.items()) def test_probe_video_from_file(self, test_video, config): """ Test the case when decoder probes a video file """ full_path = os.path.join(VIDEO_DIR, test_video) probe_result = torch.ops.video_reader.probe_video_from_file(full_path) self.check_probe_result(probe_result, config) @pytest.mark.parametrize("test_video,config", test_videos.items()) def test_probe_video_from_memory(self, test_video, config): """ Test the case when decoder probes a video in memory """ _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor) self.check_probe_result(probe_result, config) @pytest.mark.parametrize("test_video,config", test_videos.items()) def test_probe_video_from_memory_script(self, test_video, config): scripted_fun = torch.jit.script(io._probe_video_from_memory) assert scripted_fun is not None _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) probe_result = scripted_fun(video_tensor) self.check_meta_result(probe_result, config) @pytest.mark.parametrize("test_video", test_videos.keys()) def test_read_video_from_memory_scripted(self, test_video): """ Test the case when video is already in memory, and decoder reads data in memory """ # video related width, height, min_dimension, max_dimension = 0, 0, 0, 0 video_start_pts, video_end_pts = 0, -1 video_timebase_num, video_timebase_den = 0, 1 # audio related samples, channels = 0, 0 audio_start_pts, audio_end_pts = 0, -1 audio_timebase_num, audio_timebase_den = 0, 1 scripted_fun = torch.jit.script(io._read_video_from_memory) assert scripted_fun is not None _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video) # decode all frames using cpp decoder scripted_fun( video_tensor, SEEK_FRAME_MARGIN, 1, # readVideoStream width, height, min_dimension, max_dimension, [video_start_pts, video_end_pts], video_timebase_num, video_timebase_den, 1, # readAudioStream samples, channels, [audio_start_pts, audio_end_pts], audio_timebase_num, audio_timebase_den, ) # FUTURE: check value of video / audio frames def test_invalid_file(self): set_video_backend("video_reader") with pytest.raises(RuntimeError): io.read_video("foo.mp4") set_video_backend("pyav") with pytest.raises(RuntimeError): io.read_video("foo.mp4") @pytest.mark.parametrize("test_video", test_videos.keys()) @pytest.mark.parametrize("backend", ["video_reader", "pyav"]) @pytest.mark.parametrize("start_offset", [0, 500]) @pytest.mark.parametrize("end_offset", [3000, None]) def test_audio_present_pts(self, test_video, backend, start_offset, end_offset): """Test if audio frames are returned with pts unit.""" full_path = os.path.join(VIDEO_DIR, test_video) container = av.open(full_path) if container.streams.audio: set_video_backend(backend) _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts") assert all([dimension > 0 for dimension in audio.shape[:2]]) @pytest.mark.parametrize("test_video", test_videos.keys()) @pytest.mark.parametrize("backend", ["video_reader", "pyav"]) @pytest.mark.parametrize("start_offset", [0, 0.1]) @pytest.mark.parametrize("end_offset", [0.3, None]) def test_audio_present_sec(self, test_video, backend, start_offset, end_offset): """Test if audio frames are returned with sec unit.""" full_path = os.path.join(VIDEO_DIR, test_video) container = av.open(full_path) if container.streams.audio: set_video_backend(backend) _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec") assert all([dimension > 0 for dimension in audio.shape[:2]]) if __name__ == "__main__": pytest.main([__file__])