12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254 |
- import collections
- import math
- import os
- from fractions import Fraction
- import numpy as np
- import pytest
- import torch
- import torchvision.io as io
- from common_utils import assert_equal
- from numpy.random import randint
- from pytest import approx
- from torchvision import set_video_backend
- from torchvision.io import _HAS_VIDEO_OPT
# PyAV is an optional test dependency: fall back to None so tests that
# need it can skip (via pytest.mark.skipif below) instead of erroring at
# import time.
try:
    import av

    # Do a version test too
    io.video._check_av_available()
except ImportError:
    av = None
# Directory holding the checked-in sample videos used by every test below.
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")

# Field names of the per-video ground-truth record.
CheckerConfig = [
    "duration",
    "video_fps",
    "audio_sample_rate",
    # We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are
    # slightly different between TorchVision decoder and PyAv decoder. So omit it during check
    "check_aframes",
    "check_aframe_pts",
]
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))

# Default checker config: compare everything (durations/fps fields are
# unused by compare_decoding_result, so they are left as 0 here).
all_check_config = GroundTruth(
    duration=0,
    video_fps=0,
    audio_sample_rate=0,
    check_aframes=True,
    check_aframe_pts=True,
)
- test_videos = {
- "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
- duration=2.0,
- video_fps=30.0,
- audio_sample_rate=None,
- check_aframes=True,
- check_aframe_pts=True,
- ),
- "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
- duration=2.0,
- video_fps=30.0,
- audio_sample_rate=None,
- check_aframes=True,
- check_aframe_pts=True,
- ),
- "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
- duration=2.0,
- video_fps=30.0,
- audio_sample_rate=None,
- check_aframes=True,
- check_aframe_pts=True,
- ),
- "v_SoccerJuggling_g23_c01.avi": GroundTruth(
- duration=8.0,
- video_fps=29.97,
- audio_sample_rate=None,
- check_aframes=True,
- check_aframe_pts=True,
- ),
- "v_SoccerJuggling_g24_c01.avi": GroundTruth(
- duration=8.0,
- video_fps=29.97,
- audio_sample_rate=None,
- check_aframes=True,
- check_aframe_pts=True,
- ),
- "R6llTwEh07w.mp4": GroundTruth(
- duration=10.0,
- video_fps=30.0,
- audio_sample_rate=44100,
- # PyAv miss one audio frame at the beginning (pts=0)
- check_aframes=False,
- check_aframe_pts=False,
- ),
- "SOX5yA1l24A.mp4": GroundTruth(
- duration=11.0,
- video_fps=29.97,
- audio_sample_rate=48000,
- # PyAv miss one audio frame at the beginning (pts=0)
- check_aframes=False,
- check_aframe_pts=False,
- ),
- "WUzgd7C1pWA.mp4": GroundTruth(
- duration=11.0,
- video_fps=29.97,
- audio_sample_rate=48000,
- # PyAv miss one audio frame at the beginning (pts=0)
- check_aframes=False,
- check_aframe_pts=False,
- ),
- }
- DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase")
- # av_seek_frame is imprecise so seek to a timestamp earlier by a margin
- # The unit of margin is second
- SEEK_FRAME_MARGIN = 0.25
def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4):
    """Decode frames from ``container`` whose pts fall within [start_pts, end_pts].

    Args:
        container: pyav container
        start_pts/end_pts: the starting/ending Presentation TimeStamp where
            frames are read
        stream: pyav stream
        stream_name: a dictionary of streams. For example, {"video": 0} means
            video stream at stream index 0
        buffer_size: pts of frames decoded by PyAv is not guaranteed to be in
            ascending order. We need to decode more frames even when we meet end
            pts

    Returns:
        list of decoded frames, sorted by pts.
    """
    # Seeking in the stream is imprecise, so back up by a small pts margin
    # before decoding and filter out too-early frames below.
    margin = 1
    container.seek(max(start_pts - margin, 0), any_frame=False, backward=True, stream=stream)

    collected = {}
    past_end_count = 0
    for frame in container.decode(**stream_name):
        pts = frame.pts
        if pts < start_pts:
            # frame produced by the imprecise seek; discard it
            continue
        if pts > end_pts:
            # keep decoding a few more frames in case pts arrive out of order
            past_end_count += 1
            if past_end_count >= buffer_size:
                break
        else:
            collected[pts] = frame
    return [collected[pts] for pts in sorted(collected)]
def _get_timebase_by_av_module(full_path):
    """Return ``(video_time_base, audio_time_base)`` of a file via PyAV.

    Args:
        full_path: video file path

    Returns:
        tuple of Fractions; audio_time_base is None when the file has no
        audio stream.
    """
    container = av.open(full_path)
    try:
        video_time_base = container.streams.video[0].time_base
        if container.streams.audio:
            audio_time_base = container.streams.audio[0].time_base
        else:
            audio_time_base = None
    finally:
        # Bug fix: the container was previously never closed, leaking the
        # underlying file handle (cf. _decode_frames_by_av_module, which
        # does close its container).
        container.close()
    return video_time_base, audio_time_base
- def _fraction_to_tensor(fraction):
- ret = torch.zeros([2], dtype=torch.int32)
- ret[0] = fraction.numerator
- ret[1] = fraction.denominator
- return ret
def _decode_frames_by_av_module(
    full_path,
    video_start_pts=0,
    video_end_pts=None,
    audio_start_pts=0,
    audio_end_pts=None,
):
    """
    Use PyAv to decode video frames. This provides a reference for our decoder
    to compare the decoding results.
    Input arguments:
        full_path: video file path
        video_start_pts/video_end_pts: the starting/ending Presentation TimeStamp where
            frames are read
    Returns:
        DecoderResult with video/audio frames, their pts, and both timebases.
    """
    # None means "decode to the end of the stream"
    if video_end_pts is None:
        video_end_pts = float("inf")
    if audio_end_pts is None:
        audio_end_pts = float("inf")
    container = av.open(full_path)

    video_frames = []
    vtimebase = torch.zeros([0], dtype=torch.int32)
    if container.streams.video:
        video_frames = _read_from_stream(
            container,
            video_start_pts,
            video_end_pts,
            container.streams.video[0],
            {"video": 0},
        )
        # container.streams.video[0].average_rate is not a reliable estimator of
        # frame rate. It can be wrong for certain codec, such as VP80
        # So we do not return video fps here
        vtimebase = _fraction_to_tensor(container.streams.video[0].time_base)

    audio_frames = []
    atimebase = torch.zeros([0], dtype=torch.int32)
    if container.streams.audio:
        audio_frames = _read_from_stream(
            container,
            audio_start_pts,
            audio_end_pts,
            container.streams.audio[0],
            {"audio": 0},
        )
        atimebase = _fraction_to_tensor(container.streams.audio[0].time_base)
    container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    if vframes:
        vframes = torch.as_tensor(np.stack(vframes))
    else:
        # Bug fix: np.stack raises ValueError on an empty list, so a file
        # with no decoded video frames used to crash here. Mirror the audio
        # branch and return an empty tensor instead.
        vframes = torch.empty((0,), dtype=torch.uint8)
    vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64)

    aframes = [frame.to_ndarray() for frame in audio_frames]
    if aframes:
        # concatenate along time, then transpose to (num_samples, channels)
        aframes = np.transpose(np.concatenate(aframes, axis=1))
        aframes = torch.as_tensor(aframes)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)
    aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64)

    return DecoderResult(
        vframes=vframes,
        vframe_pts=vframe_pts,
        vtimebase=vtimebase,
        aframes=aframes,
        aframe_pts=aframe_pts,
        atimebase=atimebase,
    )
- def _pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
- """convert pts between different time bases
- Args:
- pts: presentation timestamp, float
- timebase_from: original timebase. Fraction
- timebase_to: new timebase. Fraction
- round_func: rounding function.
- """
- new_pts = Fraction(pts, 1) * timebase_from / timebase_to
- return int(round_func(new_pts))
- def _get_video_tensor(video_dir, video_file):
- """open a video file, and represent the video data by a PT tensor"""
- full_path = os.path.join(video_dir, video_file)
- assert os.path.exists(full_path), "File not found: %s" % full_path
- with open(full_path, "rb") as fp:
- video_tensor = torch.frombuffer(fp.read(), dtype=torch.uint8)
- return full_path, video_tensor
- @pytest.mark.skipif(av is None, reason="PyAV unavailable")
- @pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
- class TestVideoReader:
- def check_separate_decoding_result(self, tv_result, config):
- """check the decoding results from TorchVision decoder"""
- (
- vframes,
- vframe_pts,
- vtimebase,
- vfps,
- vduration,
- aframes,
- aframe_pts,
- atimebase,
- asample_rate,
- aduration,
- ) = tv_result
- video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
- assert video_duration == approx(config.duration, abs=0.5)
- assert vfps.item() == approx(config.video_fps, abs=0.5)
- if asample_rate.numel() > 0:
- assert asample_rate.item() == config.audio_sample_rate
- audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
- assert audio_duration == approx(config.duration, abs=0.5)
- # check if pts of video frames are sorted in ascending order
- for i in range(len(vframe_pts) - 1):
- assert vframe_pts[i] < vframe_pts[i + 1]
- if len(aframe_pts) > 1:
- # check if pts of audio frames are sorted in ascending order
- for i in range(len(aframe_pts) - 1):
- assert aframe_pts[i] < aframe_pts[i + 1]
- def check_probe_result(self, result, config):
- vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
- video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item())
- assert video_duration == approx(config.duration, abs=0.5)
- assert vfps.item() == approx(config.video_fps, abs=0.5)
- if asample_rate.numel() > 0:
- assert asample_rate.item() == config.audio_sample_rate
- audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item())
- assert audio_duration == approx(config.duration, abs=0.5)
- def check_meta_result(self, result, config):
- assert result.video_duration == approx(config.duration, abs=0.5)
- assert result.video_fps == approx(config.video_fps, abs=0.5)
- if result.has_audio > 0:
- assert result.audio_sample_rate == config.audio_sample_rate
- assert result.audio_duration == approx(config.duration, abs=0.5)
    def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
        """
        Compare decoding results from two sources.
        Args:
            tv_result: decoding results from TorchVision decoder
            ref_result: reference decoding results which can be from either PyAv
                decoder or TorchVision decoder with getPtsOnly = 1
            config: config of decoding results checker
        """
        (
            vframes,
            vframe_pts,
            vtimebase,
            _vfps,
            _vduration,
            aframes,
            aframe_pts,
            atimebase,
            _asample_rate,
            _aduration,
        ) = tv_result
        if isinstance(ref_result, list):
            # the ref_result is from new video_reader decoder
            ref_result = DecoderResult(
                vframes=ref_result[0],
                vframe_pts=ref_result[1],
                vtimebase=ref_result[2],
                aframes=ref_result[5],
                aframe_pts=ref_result[6],
                atimebase=ref_result[7],
            )
        if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
            # small mean absolute pixel difference is tolerated: decoders may
            # use slightly different color conversion paths
            mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float()))
            assert mean_delta == approx(0.0, abs=8.0)

        # frame pts must agree (tolerance 1 tick for rounding differences)
        mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()))
        assert mean_delta == approx(0.0, abs=1.0)

        assert_equal(vtimebase, ref_result.vtimebase)

        if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0:
            """Audio stream is available and audio frame is required to return
            from decoder"""
            assert_equal(aframes, ref_result.aframes)

        if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0:
            """Audio stream is available"""
            assert_equal(aframe_pts, ref_result.aframe_pts)

            assert_equal(atimebase, ref_result.atimebase)
- @pytest.mark.parametrize("test_video", test_videos.keys())
- def test_stress_test_read_video_from_file(self, test_video):
- pytest.skip(
- "This stress test will iteratively decode the same set of videos."
- "It helps to detect memory leak but it takes lots of time to run."
- "By default, it is disabled"
- )
- num_iter = 10000
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- for _i in range(num_iter):
- full_path = os.path.join(VIDEO_DIR, test_video)
- # pass 1: decode all frames using new decoder
- torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- def test_read_video_from_file(self, test_video, config):
- """
- Test the case when decoder starts with a video file to decode frames.
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- # pass 1: decode all frames using new decoder
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- # pass 2: decode all frames using av
- pyav_result = _decode_frames_by_av_module(full_path)
- # check results from TorchVision decoder
- self.check_separate_decoding_result(tv_result, config)
- # compare decoding results
- self.compare_decoding_result(tv_result, pyav_result, config)
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- @pytest.mark.parametrize("read_video_stream,read_audio_stream", [(1, 0), (0, 1)])
- def test_read_video_from_file_read_single_stream_only(
- self, test_video, config, read_video_stream, read_audio_stream
- ):
- """
- Test the case when decoder starts with a video file to decode frames, and
- only reads video stream and ignores audio stream
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- # decode all frames using new decoder
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- read_video_stream,
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- read_audio_stream,
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- (
- vframes,
- vframe_pts,
- vtimebase,
- vfps,
- vduration,
- aframes,
- aframe_pts,
- atimebase,
- asample_rate,
- aduration,
- ) = tv_result
- assert (vframes.numel() > 0) is bool(read_video_stream)
- assert (vframe_pts.numel() > 0) is bool(read_video_stream)
- assert (vtimebase.numel() > 0) is bool(read_video_stream)
- assert (vfps.numel() > 0) is bool(read_video_stream)
- expect_audio_data = read_audio_stream == 1 and config.audio_sample_rate is not None
- assert (aframes.numel() > 0) is bool(expect_audio_data)
- assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
- assert (atimebase.numel() > 0) is bool(expect_audio_data)
- assert (asample_rate.numel() > 0) is bool(expect_audio_data)
- @pytest.mark.parametrize("test_video", test_videos.keys())
- def test_read_video_from_file_rescale_min_dimension(self, test_video):
- """
- Test the case when decoder starts with a video file to decode frames, and
- video min dimension between height and width is set.
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 128, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
- @pytest.mark.parametrize("test_video", test_videos.keys())
- def test_read_video_from_file_rescale_max_dimension(self, test_video):
- """
- Test the case when decoder starts with a video file to decode frames, and
- video min dimension between height and width is set.
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 85
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
- @pytest.mark.parametrize("test_video", test_videos.keys())
- def test_read_video_from_file_rescale_both_min_max_dimension(self, test_video):
- """
- Test the case when decoder starts with a video file to decode frames, and
- video min dimension between height and width is set.
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 64, 85
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
- assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
- @pytest.mark.parametrize("test_video", test_videos.keys())
- def test_read_video_from_file_rescale_width(self, test_video):
- """
- Test the case when decoder starts with a video file to decode frames, and
- video width is set.
- """
- # video related
- width, height, min_dimension, max_dimension = 256, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- assert tv_result[0].size(2) == width
- @pytest.mark.parametrize("test_video", test_videos.keys())
- def test_read_video_from_file_rescale_height(self, test_video):
- """
- Test the case when decoder starts with a video file to decode frames, and
- video height is set.
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 224, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- assert tv_result[0].size(1) == height
- @pytest.mark.parametrize("test_video", test_videos.keys())
- def test_read_video_from_file_rescale_width_and_height(self, test_video):
- """
- Test the case when decoder starts with a video file to decode frames, and
- both video height and width are set.
- """
- # video related
- width, height, min_dimension, max_dimension = 320, 240, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- assert tv_result[0].size(1) == height
- assert tv_result[0].size(2) == width
- @pytest.mark.parametrize("test_video", test_videos.keys())
- @pytest.mark.parametrize("samples", [9600, 96000])
- def test_read_video_from_file_audio_resampling(self, test_video, samples):
- """
- Test the case when decoder starts with a video file to decode frames, and
- audio waveform are resampled
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- channels = 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path = os.path.join(VIDEO_DIR, test_video)
- tv_result = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- (
- vframes,
- vframe_pts,
- vtimebase,
- vfps,
- vduration,
- aframes,
- aframe_pts,
- atimebase,
- asample_rate,
- aduration,
- ) = tv_result
- if aframes.numel() > 0:
- assert samples == asample_rate.item()
- assert 1 == aframes.size(1)
- # when audio stream is found
- duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
- assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item())
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- def test_compare_read_video_from_memory_and_file(self, test_video, config):
- """
- Test the case when video is already in memory, and decoder reads data in memory
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
- # pass 1: decode all frames using cpp decoder
- tv_result_memory = torch.ops.video_reader.read_video_from_memory(
- video_tensor,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- self.check_separate_decoding_result(tv_result_memory, config)
- # pass 2: decode all frames from file
- tv_result_file = torch.ops.video_reader.read_video_from_file(
- full_path,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- self.check_separate_decoding_result(tv_result_file, config)
- # finally, compare results decoded from memory and file
- self.compare_decoding_result(tv_result_memory, tv_result_file)
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- def test_read_video_from_memory(self, test_video, config):
- """
- Test the case when video is already in memory, and decoder reads data in memory
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
- # pass 1: decode all frames using cpp decoder
- tv_result = torch.ops.video_reader.read_video_from_memory(
- video_tensor,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- # pass 2: decode all frames using av
- pyav_result = _decode_frames_by_av_module(full_path)
- self.check_separate_decoding_result(tv_result, config)
- self.compare_decoding_result(tv_result, pyav_result, config)
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- def test_read_video_from_memory_get_pts_only(self, test_video, config):
- """
- Test the case when video is already in memory, and decoder reads data in memory.
- Compare frame pts between decoding for pts only and full decoding
- for both pts and frame data
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
- # pass 1: decode all frames using cpp decoder
- tv_result = torch.ops.video_reader.read_video_from_memory(
- video_tensor,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- assert abs(config.video_fps - tv_result[3].item()) < 0.01
- # pass 2: decode all frames to get PTS only using cpp decoder
- tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
- video_tensor,
- SEEK_FRAME_MARGIN,
- 1, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- assert not tv_result_pts_only[0].numel()
- assert not tv_result_pts_only[5].numel()
- self.compare_decoding_result(tv_result, tv_result_pts_only)
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- @pytest.mark.parametrize("num_frames", [4, 8, 16, 32, 64, 128])
- def test_read_video_in_range_from_memory(self, test_video, config, num_frames):
- """
- Test the case when video is already in memory, and decoder reads data in memory.
- In addition, decoder takes meaningful start- and end PTS as input, and decode
- frames within that interval
- """
- full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- # pass 1: decode all frames using new decoder
- tv_result = torch.ops.video_reader.read_video_from_memory(
- video_tensor,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- (
- vframes,
- vframe_pts,
- vtimebase,
- vfps,
- vduration,
- aframes,
- aframe_pts,
- atimebase,
- asample_rate,
- aduration,
- ) = tv_result
- assert abs(config.video_fps - vfps.item()) < 0.01
- start_pts_ind_max = vframe_pts.size(0) - num_frames
- if start_pts_ind_max <= 0:
- return
- # randomly pick start pts
- start_pts_ind = randint(0, start_pts_ind_max)
- end_pts_ind = start_pts_ind + num_frames - 1
- video_start_pts = vframe_pts[start_pts_ind]
- video_end_pts = vframe_pts[end_pts_ind]
- video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1]
- if len(atimebase) > 0:
- # when audio stream is available
- audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
- audio_start_pts = _pts_convert(
- video_start_pts.item(),
- Fraction(video_timebase_num.item(), video_timebase_den.item()),
- Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
- math.floor,
- )
- audio_end_pts = _pts_convert(
- video_end_pts.item(),
- Fraction(video_timebase_num.item(), video_timebase_den.item()),
- Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
- math.ceil,
- )
- # pass 2: decode frames in the randomly generated range
- tv_result = torch.ops.video_reader.read_video_from_memory(
- video_tensor,
- SEEK_FRAME_MARGIN,
- 0, # getPtsOnly
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- video_start_pts,
- video_end_pts,
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- audio_start_pts,
- audio_end_pts,
- audio_timebase_num,
- audio_timebase_den,
- )
- # pass 3: decode frames in range using PyAv
- video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path)
- video_start_pts_av = _pts_convert(
- video_start_pts.item(),
- Fraction(video_timebase_num.item(), video_timebase_den.item()),
- Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
- math.floor,
- )
- video_end_pts_av = _pts_convert(
- video_end_pts.item(),
- Fraction(video_timebase_num.item(), video_timebase_den.item()),
- Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
- math.ceil,
- )
- if audio_timebase_av:
- audio_start_pts = _pts_convert(
- video_start_pts.item(),
- Fraction(video_timebase_num.item(), video_timebase_den.item()),
- Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
- math.floor,
- )
- audio_end_pts = _pts_convert(
- video_end_pts.item(),
- Fraction(video_timebase_num.item(), video_timebase_den.item()),
- Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
- math.ceil,
- )
- pyav_result = _decode_frames_by_av_module(
- full_path,
- video_start_pts_av,
- video_end_pts_av,
- audio_start_pts,
- audio_end_pts,
- )
- assert tv_result[0].size(0) == num_frames
- if pyav_result.vframes.size(0) == num_frames:
- # if PyAv decodes a different number of video frames, skip
- # comparing the decoding results between Torchvision video reader
- # and PyAv
- self.compare_decoding_result(tv_result, pyav_result, config)
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- def test_probe_video_from_file(self, test_video, config):
- """
- Test the case when decoder probes a video file
- """
- full_path = os.path.join(VIDEO_DIR, test_video)
- probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
- self.check_probe_result(probe_result, config)
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- def test_probe_video_from_memory(self, test_video, config):
- """
- Test the case when decoder probes a video in memory
- """
- _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
- probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
- self.check_probe_result(probe_result, config)
- @pytest.mark.parametrize("test_video,config", test_videos.items())
- def test_probe_video_from_memory_script(self, test_video, config):
- scripted_fun = torch.jit.script(io._probe_video_from_memory)
- assert scripted_fun is not None
- _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
- probe_result = scripted_fun(video_tensor)
- self.check_meta_result(probe_result, config)
- @pytest.mark.parametrize("test_video", test_videos.keys())
- def test_read_video_from_memory_scripted(self, test_video):
- """
- Test the case when video is already in memory, and decoder reads data in memory
- """
- # video related
- width, height, min_dimension, max_dimension = 0, 0, 0, 0
- video_start_pts, video_end_pts = 0, -1
- video_timebase_num, video_timebase_den = 0, 1
- # audio related
- samples, channels = 0, 0
- audio_start_pts, audio_end_pts = 0, -1
- audio_timebase_num, audio_timebase_den = 0, 1
- scripted_fun = torch.jit.script(io._read_video_from_memory)
- assert scripted_fun is not None
- _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
- # decode all frames using cpp decoder
- scripted_fun(
- video_tensor,
- SEEK_FRAME_MARGIN,
- 1, # readVideoStream
- width,
- height,
- min_dimension,
- max_dimension,
- [video_start_pts, video_end_pts],
- video_timebase_num,
- video_timebase_den,
- 1, # readAudioStream
- samples,
- channels,
- [audio_start_pts, audio_end_pts],
- audio_timebase_num,
- audio_timebase_den,
- )
- # FUTURE: check value of video / audio frames
- def test_invalid_file(self):
- set_video_backend("video_reader")
- with pytest.raises(RuntimeError):
- io.read_video("foo.mp4")
- set_video_backend("pyav")
- with pytest.raises(RuntimeError):
- io.read_video("foo.mp4")
- @pytest.mark.parametrize("test_video", test_videos.keys())
- @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
- @pytest.mark.parametrize("start_offset", [0, 500])
- @pytest.mark.parametrize("end_offset", [3000, None])
- def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
- """Test if audio frames are returned with pts unit."""
- full_path = os.path.join(VIDEO_DIR, test_video)
- container = av.open(full_path)
- if container.streams.audio:
- set_video_backend(backend)
- _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts")
- assert all([dimension > 0 for dimension in audio.shape[:2]])
- @pytest.mark.parametrize("test_video", test_videos.keys())
- @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
- @pytest.mark.parametrize("start_offset", [0, 0.1])
- @pytest.mark.parametrize("end_offset", [0.3, None])
- def test_audio_present_sec(self, test_video, backend, start_offset, end_offset):
- """Test if audio frames are returned with sec unit."""
- full_path = os.path.join(VIDEO_DIR, test_video)
- container = av.open(full_path)
- if container.streams.audio:
- set_video_backend(backend)
- _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec")
- assert all([dimension > 0 for dimension in audio.shape[:2]])
if __name__ == "__main__":
    # Allow running this test module directly, outside a pytest invocation.
    pytest.main([__file__])
|