123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415 |
- import gc
- import math
- import os
- import re
- import warnings
- from fractions import Fraction
- from typing import Any, Dict, List, Optional, Tuple, Union
- import numpy as np
- import torch
- from ..utils import _log_api_usage_once
- from . import _video_opt
# PyAV is an optional dependency. When it cannot be imported, or the installed
# build is unusable, `av` is bound to an ImportError instance instead of the
# module so that the error can be raised later by `_check_av_available`.
try:
    import av

    # Set PyAV's logging level to ERROR.
    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        # VideoFrame has no `pict_type` attribute: treat this PyAV as too old.
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    # PyAV is not importable at all.
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )
def _check_av_available() -> None:
    """Raise the stored import error when `av` is an exception instead of the PyAV module."""
    if not isinstance(av, Exception):
        return
    raise av
def _av_available() -> bool:
    """Return True when `av` is bound to something other than an exception instance."""
    if isinstance(av, Exception):
        return False
    return True
# PyAV has some reference cycles, so `_read_from_stream` periodically forces a
# garbage collection (presumably to reclaim them).
# Number of times `_read_from_stream` has been called so far.
_CALLED_TIMES = 0
# `gc.collect()` is triggered once every this many calls to `_read_from_stream`.
_GC_COLLECTION_INTERVAL = 10
def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    Args:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): video frames per second. A ``float`` value is rounded to the
            nearest integer before being handed to PyAV.
        video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc.
        options (Dict): dictionary containing options to be passed into the PyAV video stream
        audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
            and N is the number of samples
        audio_fps (Number): audio sample rate, typically 44100 or 48000
        audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
        audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream

    Raises:
        ImportError: if PyAV is not installed or is too old.
        KeyError: if the audio stream's sample format is not one of the supported formats.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_video)
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case.
    # np.round alone still returns a floating point value (np.float64), so
    # convert explicitly to a plain int.
    if isinstance(fps, float):
        fps = int(np.round(fps))

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            # numpy dtype matching each supported PyAV sample format name
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            audio_sample_fmt = container.streams.audio[0].format.name

            # cast the samples to the dtype expected by the stream's sample format
            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).numpy().astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)

            # Flush the audio stream
            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            frame.pict_type = "NONE"
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    """
    Decode the frames of ``stream`` whose pts lie in ``[start_offset, end_offset]``.

    Args:
        container: opened PyAV container to seek in and decode from
        start_offset: start of the window, in seconds if ``pts_unit == "sec"``, else in pts
        end_offset: end of the window, in the same unit (may be ``float("inf")``)
        pts_unit: ``"sec"`` converts the offsets using ``stream.time_base``; any other
            value leaves them untouched and emits a warning
        stream: the stream used for seeking and for DivX packed-B-frame detection
        stream_name: keyword arguments forwarded to ``container.decode``

    Returns:
        The selected frames sorted by pts. When no decoded frame sits exactly at a
        positive ``start_offset``, the closest earlier frame is prepended. An empty
        list is returned if seeking fails.
    """
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit != "sec":
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
    else:
        # TODO: we should change all of this from ground up to simply take
        # sec and convert to MS in C++
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))

    decoded = {}
    reorder_needed = True
    reorder_window = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a
        # single pkt), so some extra frames must be buffered to sort properly.
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            tail = extradata[extradata.find(b"DivX") :]
            match = re.search(rb"DivX(\d+)Build(\d+)(\w)", tail)
            if match is None:
                match = re.search(rb"DivX(\d+)b(\d+)(\w)", tail)
            if match is not None:
                reorder_needed = match.group(3) == b"p"

    # some files don't seek to the right location, so better be safe here
    seek_target = max(start_offset - 1, 0)
    if reorder_needed:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_target = max(seek_target - reorder_window, 0)

    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_target, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []

    frames_past_end = 0
    try:
        for frame in container.decode(**stream_name):
            decoded[frame.pts] = frame
            if frame.pts >= end_offset:
                # keep decoding a few extra frames when reordering may be needed
                if not (reorder_needed and frames_past_end < reorder_window):
                    break
                frames_past_end += 1
    except av.AVError:
        # TODO add a warning
        pass

    # ensure that the results are sorted wrt the pts
    selected = []
    for key in sorted(decoded):
        if start_offset <= decoded[key].pts <= end_offset:
            selected.append(decoded[key])

    if decoded and start_offset > 0 and start_offset not in decoded:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        earlier_pts = [key for key in decoded if key < start_offset]
        if earlier_pts:
            selected.insert(0, decoded[max(earlier_pts)])
    return selected
def _align_audio_frames(
    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
    """
    Trim decoded audio samples so they approximately cover ``[ref_start, ref_end]``.

    Args:
        aframes: tensor indexed as ``[channel, sample]`` holding the concatenated samples
        audio_frames: the decoded frames the samples came from; only the ``pts`` of the
            first and last frame are read
        ref_start: requested start pts
        ref_end: requested end pts (may be ``float("inf")``)

    Returns:
        A slice of ``aframes`` along the sample dimension.
    """
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    if total_aframes == 0:
        # nothing to trim, and the step below would divide by zero
        return aframes
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # Number of samples to drop from the end. Computing the end index as
        # `total - trim` (instead of the negative index `-trim`) avoids an
        # empty slice when the trim amount truncates to zero.
        trim = int((end - ref_end) / step_per_aframe)
        e_idx = max(total_aframes - trim, 0)
    return aframes[:, s_idx:e_idx]
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Read the video frames and the audio frames of a video file.

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)

    Raises:
        ValueError: if ``output_format`` is not supported, or ``end_pts`` is smaller than ``start_pts``
        RuntimeError: if ``filename`` does not exist
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() == "pyav":
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_stream = container.streams.video[0]
                    video_frames = _read_from_stream(
                        container, start_pts, end_pts, pts_unit, video_stream, {"video": 0}
                    )
                    video_fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)
                if container.streams.audio:
                    audio_stream = container.streams.audio[0]
                    audio_frames = _read_from_stream(
                        container, start_pts, end_pts, pts_unit, audio_stream, {"audio": 0}
                    )
                    info["audio_fps"] = container.streams.audio[0].rate
        except av.AVError:
            # TODO raise a warning?
            pass

        rgb_arrays = [frame.to_rgb().to_ndarray() for frame in video_frames]
        sample_arrays = [frame.to_ndarray() for frame in audio_frames]

        if not rgb_arrays:
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
        else:
            vframes = torch.as_tensor(np.stack(rgb_arrays))

        if not sample_arrays:
            aframes = torch.empty((1, 0), dtype=torch.float32)
        else:
            aframes = torch.as_tensor(np.concatenate(sample_arrays, 1))
            if pts_unit == "sec":
                # express the requested window in audio pts before aligning
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
    """Return True when the first stream's codec extradata exists and contains ``b"Lavc"``."""
    codec_extradata = container.streams[0].codec_context.extradata
    return codec_extradata is not None and b"Lavc" in codec_extradata
def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    """
    Collect the non-None pts values of video stream 0.

    Packets are demuxed (fast path) when `_can_read_timestamps_from_packets`
    accepts the container; otherwise the frames are decoded.
    """
    if _can_read_timestamps_from_packets(container):
        # fast path
        items = container.demux(video=0)
    else:
        items = container.decode(video=0)
    return [item.pts for item in items if item.pts is not None]
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): unit in which timestamp values will be returned
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
            presentation timestamps for each one of the frames in the video.
        video_fps (float, optional): the frame rate for the video. ``None`` when the
            file has no video stream, cannot be opened, or reports no average rate.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    video_time_base = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                # average_rate can be None (e.g. for corrupted files); float(None)
                # would raise TypeError, so leave video_fps as None in that case
                average_rate = video_stream.average_rate
                if average_rate is not None:
                    video_fps = float(average_rate)
    except av.AVError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        # pts is only non-empty when a video stream (and its time base) was found
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
|