# video_reader.py

import io
import warnings
from typing import Any, Dict, Iterator, Optional

import torch

from ..utils import _log_api_usage_once
from ._video_opt import _HAS_VIDEO_OPT

if _HAS_VIDEO_OPT:

    def _has_video_opt() -> bool:
        return True

else:

    def _has_video_opt() -> bool:
        return False


try:
    import av

    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )


class VideoReader:
    """
    Fine-grained video-reading API.
    Supports frame-by-frame reading of various streams from a single video
    container. Much like the previous video_reader API, it supports the following
    backends: video_reader, pyav, and cuda.
    Backends can be set via the `torchvision.set_video_backend` function.

    .. betastatus:: VideoReader class

    Example:
        The following example creates a :mod:`VideoReader` object, seeks to the 2s
        point, and returns a single frame::

            import torchvision
            video_path = "path_to_a_test_video"
            reader = torchvision.io.VideoReader(video_path, "video")
            reader.seek(2.0)
            frame = next(reader)

        :mod:`VideoReader` implements the iterable API, which makes it suitable for
        use in conjunction with :mod:`itertools` for more advanced reading.
        As such, we can use a :mod:`VideoReader` instance inside for loops::

            reader.seek(2)
            for frame in reader:
                frames.append(frame['data'])
            # additionally, `seek` implements a fluent API, so we can do
            for frame in reader.seek(2):
                frames.append(frame['data'])

        With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
        following code::

            for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
                frames.append(frame['data'])

        and similarly, reading 10 frames after the 2s timestamp can be achieved
        as follows::

            for frame in itertools.islice(reader.seek(2), 10):
                frames.append(frame['data'])

    .. note::

        Each stream descriptor consists of two parts: stream type (e.g. 'video') and
        a unique stream id (which is determined by the video encoding).
        In this way, if the video container contains multiple
        streams of the same type, users can access the one they want.
        If only the stream type is passed, the decoder auto-detects the first stream of that type.

    Args:
        src (string, bytes object, or tensor): The media source.
            If string-type, it must be a file path supported by FFMPEG.
            If bytes, it should be an in-memory representation of a file supported by FFMPEG.
            If Tensor, it is interpreted internally as a byte buffer.
            It must be one-dimensional, of type ``torch.uint8``.

        stream (string, optional): descriptor of the required stream, followed by the stream id,
            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
            Currently available options include ``['video', 'audio']``.

        num_threads (int, optional): number of threads used by the codec to decode video.
            The default value (0) enables multithreading with a codec-dependent heuristic.
            The performance will depend on the version of FFMPEG codecs supported.

        path (str, optional):
            .. warning::
                This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
                Please use ``src`` instead.
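
    As a minimal sketch of the in-memory usage described under ``src`` above
    (``"video.mp4"`` is a placeholder path; passing bytes assumes a non-cuda
    backend)::

        with open("video.mp4", "rb") as f:
            reader = torchvision.io.VideoReader(f.read(), "video")
        frame = next(reader)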
  89. """

    def __init__(
        self,
        src: str = "",
        stream: str = "video",
        num_threads: int = 0,
        path: Optional[str] = None,
    ) -> None:
        _log_api_usage_once(self)
        from .. import get_video_backend

        self.backend = get_video_backend()
        if isinstance(src, str):
            if src == "":
                if path is None:
                    raise TypeError("src cannot be empty")
                src = path
                warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
        elif isinstance(src, bytes):
            if self.backend in ["cuda"]:
                raise RuntimeError("VideoReader cannot be initialized from bytes object when using cuda backend.")
            elif self.backend == "pyav":
                src = io.BytesIO(src)
            else:
                with warnings.catch_warnings():
                    # Ignore the warning because we actually don't modify the buffer in this function
                    warnings.filterwarnings("ignore", message="The given buffer is not writable")
                    src = torch.frombuffer(src, dtype=torch.uint8)
        elif isinstance(src, torch.Tensor):
            if self.backend in ["cuda", "pyav"]:
                raise RuntimeError(
                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
                )
        else:
            raise TypeError("`src` must be either string, Tensor or bytes object.")

        if self.backend == "cuda":
            device = torch.device("cuda")
            self._c = torch.classes.torchvision.GPUDecoder(src, device)
        elif self.backend == "video_reader":
            if isinstance(src, str):
                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
            elif isinstance(src, torch.Tensor):
                self._c = torch.classes.torchvision.Video("", "", 0)
                self._c.init_from_memory(src, stream, num_threads)
        elif self.backend == "pyav":
            self.container = av.open(src, metadata_errors="ignore")
            # TODO: load metadata
            stream_type = stream.split(":")[0]
            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
            self.pyav_stream = {stream_type: stream_id}
            self._c = self.container.decode(**self.pyav_stream)
            # TODO: add extradata exception
        else:
            raise RuntimeError("Unknown video backend: {}".format(self.backend))

    def __next__(self) -> Dict[str, Any]:
        """Decodes and returns the next frame of the current stream.
        Frames are encoded as a dict with mandatory
        ``data`` and ``pts`` fields, where ``data`` is a tensor, and ``pts`` is the
        presentation timestamp of the frame expressed in seconds
        as a float.

        Returns:
            (dict): a dictionary containing the decoded frame (``data``)
            and the corresponding timestamp (``pts``) in seconds
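
        Example:
            A minimal sketch of consuming the returned dict (``reader`` constructed
            as in the class-level example; the tensor shape depends on the stream type)::

                frame = next(reader)
                data, pts = frame["data"], frame["pts"]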
  153. """
        if self.backend == "cuda":
            frame = self._c.next()
            if frame.numel() == 0:
                raise StopIteration
            return {"data": frame, "pts": None}
        elif self.backend == "video_reader":
            frame, pts = self._c.next()
        else:
            try:
                frame = next(self._c)
                pts = float(frame.pts * frame.time_base)
                if "video" in self.pyav_stream:
                    frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
                elif "audio" in self.pyav_stream:
                    frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
                else:
                    frame = None
            except av.error.EOFError:
                raise StopIteration

        # guard against the unknown-stream-type case above, where frame is None
        if frame is None or frame.numel() == 0:
            raise StopIteration
        return {"data": frame, "pts": pts}

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
        """Seek within current stream.

        Args:
            time_s (float): seek time in seconds
            keyframes_only (bool): allow seeking only to keyframes

        .. note::
            The current implementation is the so-called precise seek. This
            means that, following a seek, a call to :mod:`next()` will return the
            frame with the exact timestamp if it exists, or
            the first frame with a timestamp larger than ``time_s``.
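
        Example:
            A minimal sketch of a precise seek followed by a read (``reader``
            constructed as in the class-level example)::

                reader.seek(2.0)      # jump to the 2-second mark
                frame = next(reader)  # first frame with pts >= 2.0

                # ``seek`` returns ``self``, so the calls can be chained
                frame = next(reader.seek(2.0))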
  188. """
        if self.backend in ["cuda", "video_reader"]:
            self._c.seek(time_s, keyframes_only)
        else:
            # handle special case as pyav doesn't catch it
            if time_s < 0:
                time_s = 0
            temp_str = self.container.streams.get(**self.pyav_stream)[0]
            offset = int(round(time_s / temp_str.time_base))
            if not keyframes_only:
                warnings.warn("Accurate seek is not implemented for pyav backend")
            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
            self._c = self.container.decode(**self.pyav_stream)
        return self

    def get_metadata(self) -> Dict[str, Any]:
        """Returns video metadata.

        Returns:
            (dict): dictionary containing duration and frame rate for every stream
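
        Example:
            A sketch of the returned structure for a container with one video and
            one audio stream (values are illustrative, not real output)::

                {
                    "video": {"fps": [29.97], "duration": [10.0]},
                    "audio": {"framerate": [44100.0], "duration": [10.0]},
                }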
  206. """
        if self.backend == "pyav":
            metadata: Dict[str, Any] = {}
            for stream in self.container.streams:
                # video streams report "fps", everything else reports "framerate";
                # compute this per stream so that interleaved stream types don't
                # reuse the key of the previously seen stream
                rate_n = "fps" if stream.type == "video" else "framerate"
                if stream.type not in metadata:
                    metadata[stream.type] = {rate_n: [], "duration": []}
                rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate
                metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
                metadata[stream.type][rate_n].append(float(rate))
            return metadata
        return self._c.get_metadata()

    def set_current_stream(self, stream: str) -> bool:
        """Set current stream.
        Explicitly define the stream we are operating on.

        Args:
            stream (string): descriptor of the required stream. Defaults to ``"video:0"``.
                Currently available stream types include ``['video', 'audio']``.
                Each descriptor consists of two parts: stream type (e.g. 'video') and
                a unique stream id (which is determined by the video encoding).
                In this way, if the video container contains multiple
                streams of the same type, users can access the one they want.
                If only the stream type is passed, the decoder auto-detects the first
                stream of that type and returns it.

        Returns:
            (bool): True on success, False otherwise
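
        Example:
            A minimal sketch of switching to the first audio stream (``reader``
            constructed as in the class-level example)::

                if reader.set_current_stream("audio:0"):
                    frame = next(reader)  # ``frame["data"]`` is now an audio tensor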
  235. """
        if self.backend == "cuda":
            warnings.warn("GPU decoding only works with video stream.")
        if self.backend == "pyav":
            stream_type = stream.split(":")[0]
            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
            self.pyav_stream = {stream_type: stream_id}
            self._c = self.container.decode(**self.pyav_stream)
            return True
        return self._c.set_current_stream(stream)