_video_opt.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. import math
  2. import warnings
  3. from fractions import Fraction
  4. from typing import Dict, List, Optional, Tuple, Union
  5. import torch
  6. from ..extension import _load_library
  7. try:
  8. _load_library("video_reader")
  9. _HAS_VIDEO_OPT = True
  10. except (ImportError, OSError):
  11. _HAS_VIDEO_OPT = False
  12. default_timebase = Fraction(0, 1)
  13. # simple class for torch scripting
  14. # the complex Fraction class from fractions module is not scriptable
  15. class Timebase:
  16. __annotations__ = {"numerator": int, "denominator": int}
  17. __slots__ = ["numerator", "denominator"]
  18. def __init__(
  19. self,
  20. numerator: int,
  21. denominator: int,
  22. ) -> None:
  23. self.numerator = numerator
  24. self.denominator = denominator
  25. class VideoMetaData:
  26. __annotations__ = {
  27. "has_video": bool,
  28. "video_timebase": Timebase,
  29. "video_duration": float,
  30. "video_fps": float,
  31. "has_audio": bool,
  32. "audio_timebase": Timebase,
  33. "audio_duration": float,
  34. "audio_sample_rate": float,
  35. }
  36. __slots__ = [
  37. "has_video",
  38. "video_timebase",
  39. "video_duration",
  40. "video_fps",
  41. "has_audio",
  42. "audio_timebase",
  43. "audio_duration",
  44. "audio_sample_rate",
  45. ]
  46. def __init__(self) -> None:
  47. self.has_video = False
  48. self.video_timebase = Timebase(0, 1)
  49. self.video_duration = 0.0
  50. self.video_fps = 0.0
  51. self.has_audio = False
  52. self.audio_timebase = Timebase(0, 1)
  53. self.audio_duration = 0.0
  54. self.audio_sample_rate = 0.0
  55. def _validate_pts(pts_range: Tuple[int, int]) -> None:
  56. if pts_range[0] > pts_range[1] > 0:
  57. raise ValueError(
  58. f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
  59. )
  60. def _fill_info(
  61. vtimebase: torch.Tensor,
  62. vfps: torch.Tensor,
  63. vduration: torch.Tensor,
  64. atimebase: torch.Tensor,
  65. asample_rate: torch.Tensor,
  66. aduration: torch.Tensor,
  67. ) -> VideoMetaData:
  68. """
  69. Build update VideoMetaData struct with info about the video
  70. """
  71. meta = VideoMetaData()
  72. if vtimebase.numel() > 0:
  73. meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
  74. timebase = vtimebase[0].item() / float(vtimebase[1].item())
  75. if vduration.numel() > 0:
  76. meta.has_video = True
  77. meta.video_duration = float(vduration.item()) * timebase
  78. if vfps.numel() > 0:
  79. meta.video_fps = float(vfps.item())
  80. if atimebase.numel() > 0:
  81. meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
  82. timebase = atimebase[0].item() / float(atimebase[1].item())
  83. if aduration.numel() > 0:
  84. meta.has_audio = True
  85. meta.audio_duration = float(aduration.item()) * timebase
  86. if asample_rate.numel() > 0:
  87. meta.audio_sample_rate = float(asample_rate.item())
  88. return meta
  89. def _align_audio_frames(
  90. aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
  91. ) -> torch.Tensor:
  92. start, end = aframe_pts[0], aframe_pts[-1]
  93. num_samples = aframes.size(0)
  94. step_per_aframe = float(end - start + 1) / float(num_samples)
  95. s_idx = 0
  96. e_idx = num_samples
  97. if start < audio_pts_range[0]:
  98. s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
  99. if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
  100. e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
  101. return aframes[s_idx:e_idx, :]
  102. def _read_video_from_file(
  103. filename: str,
  104. seek_frame_margin: float = 0.25,
  105. read_video_stream: bool = True,
  106. video_width: int = 0,
  107. video_height: int = 0,
  108. video_min_dimension: int = 0,
  109. video_max_dimension: int = 0,
  110. video_pts_range: Tuple[int, int] = (0, -1),
  111. video_timebase: Fraction = default_timebase,
  112. read_audio_stream: bool = True,
  113. audio_samples: int = 0,
  114. audio_channels: int = 0,
  115. audio_pts_range: Tuple[int, int] = (0, -1),
  116. audio_timebase: Fraction = default_timebase,
  117. ) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
  118. """
  119. Reads a video from a file, returning both the video frames and the audio frames
  120. Args:
  121. filename (str): path to the video file
  122. seek_frame_margin (double, optional): seeking frame in the stream is imprecise. Thus,
  123. when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
  124. read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
  125. video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
  126. the size of decoded frames:
  127. - When video_width = 0, video_height = 0, video_min_dimension = 0,
  128. and video_max_dimension = 0, keep the original frame resolution
  129. - When video_width = 0, video_height = 0, video_min_dimension != 0,
  130. and video_max_dimension = 0, keep the aspect ratio and resize the
  131. frame so that shorter edge size is video_min_dimension
  132. - When video_width = 0, video_height = 0, video_min_dimension = 0,
  133. and video_max_dimension != 0, keep the aspect ratio and resize
  134. the frame so that longer edge size is video_max_dimension
  135. - When video_width = 0, video_height = 0, video_min_dimension != 0,
  136. and video_max_dimension != 0, resize the frame so that shorter
  137. edge size is video_min_dimension, and longer edge size is
  138. video_max_dimension. The aspect ratio may not be preserved
  139. - When video_width = 0, video_height != 0, video_min_dimension = 0,
  140. and video_max_dimension = 0, keep the aspect ratio and resize
  141. the frame so that frame video_height is $video_height
  142. - When video_width != 0, video_height == 0, video_min_dimension = 0,
  143. and video_max_dimension = 0, keep the aspect ratio and resize
  144. the frame so that frame video_width is $video_width
  145. - When video_width != 0, video_height != 0, video_min_dimension = 0,
  146. and video_max_dimension = 0, resize the frame so that frame
  147. video_width and video_height are set to $video_width and
  148. $video_height, respectively
  149. video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
  150. video_timebase (Fraction, optional): a Fraction rational number which denotes timebase in video stream
  151. read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
  152. audio_samples (int, optional): audio sampling rate
  153. audio_channels (int optional): audio channels
  154. audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
  155. audio_timebase (Fraction, optional): a Fraction rational number which denotes time base in audio stream
  156. Returns
  157. vframes (Tensor[T, H, W, C]): the `T` video frames
  158. aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
  159. `K` is the number of audio_channels
  160. info (Dict): metadata for the video and audio. Can contain the fields video_fps (float)
  161. and audio_fps (int)
  162. """
  163. _validate_pts(video_pts_range)
  164. _validate_pts(audio_pts_range)
  165. result = torch.ops.video_reader.read_video_from_file(
  166. filename,
  167. seek_frame_margin,
  168. 0, # getPtsOnly
  169. read_video_stream,
  170. video_width,
  171. video_height,
  172. video_min_dimension,
  173. video_max_dimension,
  174. video_pts_range[0],
  175. video_pts_range[1],
  176. video_timebase.numerator,
  177. video_timebase.denominator,
  178. read_audio_stream,
  179. audio_samples,
  180. audio_channels,
  181. audio_pts_range[0],
  182. audio_pts_range[1],
  183. audio_timebase.numerator,
  184. audio_timebase.denominator,
  185. )
  186. vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
  187. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  188. if aframes.numel() > 0:
  189. # when audio stream is found
  190. aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
  191. return vframes, aframes, info
  192. def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
  193. """
  194. Decode all video- and audio frames in the video. Only pts
  195. (presentation timestamp) is returned. The actual frame pixel data is not
  196. copied. Thus, it is much faster than read_video(...)
  197. """
  198. result = torch.ops.video_reader.read_video_from_file(
  199. filename,
  200. 0, # seek_frame_margin
  201. 1, # getPtsOnly
  202. 1, # read_video_stream
  203. 0, # video_width
  204. 0, # video_height
  205. 0, # video_min_dimension
  206. 0, # video_max_dimension
  207. 0, # video_start_pts
  208. -1, # video_end_pts
  209. 0, # video_timebase_num
  210. 1, # video_timebase_den
  211. 1, # read_audio_stream
  212. 0, # audio_samples
  213. 0, # audio_channels
  214. 0, # audio_start_pts
  215. -1, # audio_end_pts
  216. 0, # audio_timebase_num
  217. 1, # audio_timebase_den
  218. )
  219. _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
  220. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  221. vframe_pts = vframe_pts.numpy().tolist()
  222. aframe_pts = aframe_pts.numpy().tolist()
  223. return vframe_pts, aframe_pts, info
  224. def _probe_video_from_file(filename: str) -> VideoMetaData:
  225. """
  226. Probe a video file and return VideoMetaData with info about the video
  227. """
  228. result = torch.ops.video_reader.probe_video_from_file(filename)
  229. vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
  230. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  231. return info
  232. def _read_video_from_memory(
  233. video_data: torch.Tensor,
  234. seek_frame_margin: float = 0.25,
  235. read_video_stream: int = 1,
  236. video_width: int = 0,
  237. video_height: int = 0,
  238. video_min_dimension: int = 0,
  239. video_max_dimension: int = 0,
  240. video_pts_range: Tuple[int, int] = (0, -1),
  241. video_timebase_numerator: int = 0,
  242. video_timebase_denominator: int = 1,
  243. read_audio_stream: int = 1,
  244. audio_samples: int = 0,
  245. audio_channels: int = 0,
  246. audio_pts_range: Tuple[int, int] = (0, -1),
  247. audio_timebase_numerator: int = 0,
  248. audio_timebase_denominator: int = 1,
  249. ) -> Tuple[torch.Tensor, torch.Tensor]:
  250. """
  251. Reads a video from memory, returning both the video frames as the audio frames
  252. This function is torchscriptable.
  253. Args:
  254. video_data (data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes):
  255. compressed video content stored in either 1) torch.Tensor 2) python bytes
  256. seek_frame_margin (double, optional): seeking frame in the stream is imprecise.
  257. Thus, when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
  258. read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
  259. video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
  260. the size of decoded frames:
  261. - When video_width = 0, video_height = 0, video_min_dimension = 0,
  262. and video_max_dimension = 0, keep the original frame resolution
  263. - When video_width = 0, video_height = 0, video_min_dimension != 0,
  264. and video_max_dimension = 0, keep the aspect ratio and resize the
  265. frame so that shorter edge size is video_min_dimension
  266. - When video_width = 0, video_height = 0, video_min_dimension = 0,
  267. and video_max_dimension != 0, keep the aspect ratio and resize
  268. the frame so that longer edge size is video_max_dimension
  269. - When video_width = 0, video_height = 0, video_min_dimension != 0,
  270. and video_max_dimension != 0, resize the frame so that shorter
  271. edge size is video_min_dimension, and longer edge size is
  272. video_max_dimension. The aspect ratio may not be preserved
  273. - When video_width = 0, video_height != 0, video_min_dimension = 0,
  274. and video_max_dimension = 0, keep the aspect ratio and resize
  275. the frame so that frame video_height is $video_height
  276. - When video_width != 0, video_height == 0, video_min_dimension = 0,
  277. and video_max_dimension = 0, keep the aspect ratio and resize
  278. the frame so that frame video_width is $video_width
  279. - When video_width != 0, video_height != 0, video_min_dimension = 0,
  280. and video_max_dimension = 0, resize the frame so that frame
  281. video_width and video_height are set to $video_width and
  282. $video_height, respectively
  283. video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
  284. video_timebase_numerator / video_timebase_denominator (float, optional): a rational
  285. number which denotes timebase in video stream
  286. read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
  287. audio_samples (int, optional): audio sampling rate
  288. audio_channels (int optional): audio audio_channels
  289. audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
  290. audio_timebase_numerator / audio_timebase_denominator (float, optional):
  291. a rational number which denotes time base in audio stream
  292. Returns:
  293. vframes (Tensor[T, H, W, C]): the `T` video frames
  294. aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
  295. `K` is the number of channels
  296. """
  297. _validate_pts(video_pts_range)
  298. _validate_pts(audio_pts_range)
  299. if not isinstance(video_data, torch.Tensor):
  300. with warnings.catch_warnings():
  301. # Ignore the warning because we actually don't modify the buffer in this function
  302. warnings.filterwarnings("ignore", message="The given buffer is not writable")
  303. video_data = torch.frombuffer(video_data, dtype=torch.uint8)
  304. result = torch.ops.video_reader.read_video_from_memory(
  305. video_data,
  306. seek_frame_margin,
  307. 0, # getPtsOnly
  308. read_video_stream,
  309. video_width,
  310. video_height,
  311. video_min_dimension,
  312. video_max_dimension,
  313. video_pts_range[0],
  314. video_pts_range[1],
  315. video_timebase_numerator,
  316. video_timebase_denominator,
  317. read_audio_stream,
  318. audio_samples,
  319. audio_channels,
  320. audio_pts_range[0],
  321. audio_pts_range[1],
  322. audio_timebase_numerator,
  323. audio_timebase_denominator,
  324. )
  325. vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
  326. if aframes.numel() > 0:
  327. # when audio stream is found
  328. aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
  329. return vframes, aframes
  330. def _read_video_timestamps_from_memory(
  331. video_data: torch.Tensor,
  332. ) -> Tuple[List[int], List[int], VideoMetaData]:
  333. """
  334. Decode all frames in the video. Only pts (presentation timestamp) is returned.
  335. The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
  336. is much faster than read_video(...)
  337. """
  338. if not isinstance(video_data, torch.Tensor):
  339. with warnings.catch_warnings():
  340. # Ignore the warning because we actually don't modify the buffer in this function
  341. warnings.filterwarnings("ignore", message="The given buffer is not writable")
  342. video_data = torch.frombuffer(video_data, dtype=torch.uint8)
  343. result = torch.ops.video_reader.read_video_from_memory(
  344. video_data,
  345. 0, # seek_frame_margin
  346. 1, # getPtsOnly
  347. 1, # read_video_stream
  348. 0, # video_width
  349. 0, # video_height
  350. 0, # video_min_dimension
  351. 0, # video_max_dimension
  352. 0, # video_start_pts
  353. -1, # video_end_pts
  354. 0, # video_timebase_num
  355. 1, # video_timebase_den
  356. 1, # read_audio_stream
  357. 0, # audio_samples
  358. 0, # audio_channels
  359. 0, # audio_start_pts
  360. -1, # audio_end_pts
  361. 0, # audio_timebase_num
  362. 1, # audio_timebase_den
  363. )
  364. _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
  365. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  366. vframe_pts = vframe_pts.numpy().tolist()
  367. aframe_pts = aframe_pts.numpy().tolist()
  368. return vframe_pts, aframe_pts, info
  369. def _probe_video_from_memory(
  370. video_data: torch.Tensor,
  371. ) -> VideoMetaData:
  372. """
  373. Probe a video in memory and return VideoMetaData with info about the video
  374. This function is torchscriptable
  375. """
  376. if not isinstance(video_data, torch.Tensor):
  377. with warnings.catch_warnings():
  378. # Ignore the warning because we actually don't modify the buffer in this function
  379. warnings.filterwarnings("ignore", message="The given buffer is not writable")
  380. video_data = torch.frombuffer(video_data, dtype=torch.uint8)
  381. result = torch.ops.video_reader.probe_video_from_memory(video_data)
  382. vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
  383. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  384. return info
  385. def _read_video(
  386. filename: str,
  387. start_pts: Union[float, Fraction] = 0,
  388. end_pts: Optional[Union[float, Fraction]] = None,
  389. pts_unit: str = "pts",
  390. ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
  391. if end_pts is None:
  392. end_pts = float("inf")
  393. if pts_unit == "pts":
  394. warnings.warn(
  395. "The pts_unit 'pts' gives wrong results and will be removed in a "
  396. + "follow-up version. Please use pts_unit 'sec'."
  397. )
  398. info = _probe_video_from_file(filename)
  399. has_video = info.has_video
  400. has_audio = info.has_audio
  401. def get_pts(time_base):
  402. start_offset = start_pts
  403. end_offset = end_pts
  404. if pts_unit == "sec":
  405. start_offset = int(math.floor(start_pts * (1 / time_base)))
  406. if end_offset != float("inf"):
  407. end_offset = int(math.ceil(end_pts * (1 / time_base)))
  408. if end_offset == float("inf"):
  409. end_offset = -1
  410. return start_offset, end_offset
  411. video_pts_range = (0, -1)
  412. video_timebase = default_timebase
  413. if has_video:
  414. video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
  415. video_pts_range = get_pts(video_timebase)
  416. audio_pts_range = (0, -1)
  417. audio_timebase = default_timebase
  418. if has_audio:
  419. audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
  420. audio_pts_range = get_pts(audio_timebase)
  421. vframes, aframes, info = _read_video_from_file(
  422. filename,
  423. read_video_stream=True,
  424. video_pts_range=video_pts_range,
  425. video_timebase=video_timebase,
  426. read_audio_stream=True,
  427. audio_pts_range=audio_pts_range,
  428. audio_timebase=audio_timebase,
  429. )
  430. _info = {}
  431. if has_video:
  432. _info["video_fps"] = info.video_fps
  433. if has_audio:
  434. _info["audio_fps"] = info.audio_sample_rate
  435. return vframes, aframes, _info
  436. def _read_video_timestamps(
  437. filename: str, pts_unit: str = "pts"
  438. ) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
  439. if pts_unit == "pts":
  440. warnings.warn(
  441. "The pts_unit 'pts' gives wrong results and will be removed in a "
  442. + "follow-up version. Please use pts_unit 'sec'."
  443. )
  444. pts: Union[List[int], List[Fraction]]
  445. pts, _, info = _read_video_timestamps_from_file(filename)
  446. if pts_unit == "sec":
  447. video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
  448. pts = [x * video_time_base for x in pts]
  449. video_fps = info.video_fps if info.has_video else None
  450. return pts, video_fps