  1. """
  2. =========
  3. Video API
  4. =========
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_video_api.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_video_api.py>` to download the full example code.
  8. This example illustrates some of the APIs that torchvision offers for
  9. videos, together with the examples on how to build datasets and more.
  10. """
# %%
# 1. Introduction: building a new video object and examining the properties
# ---------------------------------------------------------------------------
# First we select a video to test the object out. For the sake of argument
# we're using one from the Kinetics-400 dataset.
# To create it, we need to define the path and the stream we want to use.
# %%
# Chosen video statistics:
#
# - WUzgd7C1pWA.mp4
#     - source:
#         - kinetics-400
#     - video:
#         - H-264
#         - MPEG-4 AVC (part 10) (avc1)
#         - fps: 29.97
#     - audio:
#         - MPEG AAC audio (mp4a)
#         - sample rate: 48 kHz
#
import torch
import torchvision
from torchvision.datasets.utils import download_url

torchvision.set_video_backend("video_reader")

# Download the sample video
download_url(
    "https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
    ".",
    "WUzgd7C1pWA.mp4"
)
video_path = "./WUzgd7C1pWA.mp4"
# %%
# Streams are defined in a similar fashion to torch devices. We encode them as strings
# of the form ``stream_type:stream_id``, where ``stream_type`` is a string and ``stream_id`` is a long int.
# The constructor accepts passing a ``stream_type`` only, in which case the stream is auto-discovered.
# First, let's get the metadata for our particular video:

stream = "video"
video = torchvision.io.VideoReader(video_path, stream)
video.get_metadata()
# %%
# Here we can see that the video has two streams - a video and an audio stream.
# Currently available stream types include ['video', 'audio'].
# Each descriptor consists of two parts: the stream type (e.g. 'video') and a unique stream id
# (which is determined by the video encoding).
# In this way, if the video container contains multiple streams of the same type,
# users can access the one they want.
# If only the stream type is passed, the decoder auto-detects the first stream of that type and returns it.
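
# %%
# For example, we can request a specific stream explicitly by appending its id to the
# stream type. A minimal sketch (the id ``0`` below is an assumption; it refers to the
# first video stream in this particular container):
video_explicit = torchvision.io.VideoReader(video_path, "video:0")
video_explicit.get_metadata()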
# %%
# Let's read all the frames from the audio stream. By default, the return value of
# ``next(video_reader)`` is a dict containing the following fields.
#
# The return fields are:
#
# - ``data``: containing a torch.tensor
# - ``pts``: containing a float timestamp of this particular frame
metadata = video.get_metadata()
video.set_current_stream("audio")

frames = []  # we are going to save the frames here.
ptss = []  # pts is a presentation timestamp in seconds (float) of each frame
for frame in video:
    frames.append(frame['data'])
    ptss.append(frame['pts'])

print("PTS for first five frames ", ptss[:5])
print("Total number of frames: ", len(frames))
approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]
print("Approx total number of datapoints we can expect: ", approx_nf)
print("Read data size: ", frames[0].size(0) * len(frames))
# %%
# But what if we only want to read a certain time segment of the video?
# That can be done easily using the combination of our ``seek`` function, and the fact that each call
# to next returns the presentation timestamp of the returned frame in seconds.
#
# Given that our implementation relies on python iterators,
# we can leverage itertools to simplify the process and make it more pythonic.
#
# For example, if we wanted to read ten frames starting from the 2nd second:
import itertools

video.set_current_stream("video")

frames = []  # we are going to save the frames here.

# We seek to the 2nd second of the video and use islice to get the next 10 frames
for frame in itertools.islice(video.seek(2), 10):
    frames.append(frame['data'])

print("Total number of frames: ", len(frames))
# %%
# Or, if we wanted to read from the 2nd to the 5th second, we seek to the 2nd second of
# the video, then use ``itertools.takewhile`` to keep taking frames until we pass the
# target timestamp:
video.set_current_stream("video")
frames = []  # we are going to save the frames here.
video = video.seek(2)

for frame in itertools.takewhile(lambda x: x['pts'] <= 5, video):
    frames.append(frame['data'])

print("Total number of frames: ", len(frames))
approx_nf = (5 - 2) * video.get_metadata()['video']['fps'][0]
print("We can expect approx: ", approx_nf)
print("Tensor size: ", frames[0].size())
# %%
# 2. Building a sample read_video function
# -----------------------------------------
# We can utilize the methods above to build a read video function that follows
# the same API as the existing ``read_video`` function.


def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):
    if end is None:
        end = float("inf")
    if end < start:
        raise ValueError(
            "end time should be larger than start time, got "
            f"start time={start} and end time={end}"
        )

    video_frames = torch.empty(0)
    video_pts = []
    if read_video:
        video_object.set_current_stream("video")
        frames = []
        for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
            frames.append(frame['data'])
            video_pts.append(frame['pts'])
        if len(frames) > 0:
            video_frames = torch.stack(frames, 0)

    audio_frames = torch.empty(0)
    audio_pts = []
    if read_audio:
        video_object.set_current_stream("audio")
        frames = []
        for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
            frames.append(frame['data'])
            audio_pts.append(frame['pts'])
        if len(frames) > 0:
            audio_frames = torch.cat(frames, 0)

    return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()


# Total number of frames should be 327 for video and 523264 datapoints for audio
vf, af, info, meta = example_read_video(video)
print(vf.size(), af.size())
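
# %%
# For comparison, torchvision also provides a built-in :func:`torchvision.io.read_video`
# that decodes a whole file (or a time range) in a single call. A minimal sketch for
# reference; note that by default it returns video frames in THWC order and audio as a
# ``(channels, samples)`` tensor, so the layouts differ from our helper above:
ref_vf, ref_af, ref_info = torchvision.io.read_video(video_path, pts_unit="sec")
print(ref_vf.size(), ref_af.size())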
# %%
# 3. Building an example randomly sampled dataset (can be applied to the training dataset of Kinetics-400)
# ----------------------------------------------------------------------------------------------------------
# Cool, so now we can use the same principle to make a sample dataset.
# We suggest trying out an iterable dataset for this purpose.
# Here, we are going to build an example dataset that reads a randomly selected clip of frames from each video.

# %%
# Make sample dataset
import os
os.makedirs("./dataset", exist_ok=True)
os.makedirs("./dataset/1", exist_ok=True)
os.makedirs("./dataset/2", exist_ok=True)
# %%
# Download the videos
from torchvision.datasets.utils import download_url

download_url(
    "https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
    "./dataset/1",
    "WUzgd7C1pWA.mp4"
)
download_url(
    "https://github.com/pytorch/vision/blob/main/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true",
    "./dataset/1",
    "RATRACE_wave_f_nm_np1_fr_goo_37.avi"
)
download_url(
    "https://github.com/pytorch/vision/blob/main/test/assets/videos/SOX5yA1l24A.mp4?raw=true",
    "./dataset/2",
    "SOX5yA1l24A.mp4"
)
download_url(
    "https://github.com/pytorch/vision/blob/main/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true",
    "./dataset/2",
    "v_SoccerJuggling_g23_c01.avi"
)
download_url(
    "https://github.com/pytorch/vision/blob/main/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true",
    "./dataset/2",
    "v_SoccerJuggling_g24_c01.avi"
)
# %%
# Housekeeping and utilities
import os
import random

from torchvision.datasets.folder import make_dataset
from torchvision import transforms as t


def _find_classes(dir):
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def get_samples(root, extensions=(".mp4", ".avi")):
    _, class_to_idx = _find_classes(root)
    return make_dataset(root, class_to_idx, extensions=extensions)
# %%
# We are going to define the dataset and some basic arguments.
# We assume the structure of the FolderDataset, and add the following parameters:
#
# - ``clip_len``: length of a clip in frames
# - ``frame_transform``: transform for every frame individually
# - ``video_transform``: transform on a video sequence
#
# .. note::
#
#   We actually add an epoch size parameter, as using the
#   :class:`~torch.utils.data.IterableDataset` class allows us to naturally
#   oversample clips or images from each video if needed.


class RandomDataset(torch.utils.data.IterableDataset):
    def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        super().__init__()
        self.samples = get_samples(root)

        # Allow for temporal jittering
        if epoch_size is None:
            epoch_size = len(self.samples)
        self.epoch_size = epoch_size

        self.clip_len = clip_len
        self.frame_transform = frame_transform
        self.video_transform = video_transform

    def __iter__(self):
        for i in range(self.epoch_size):
            # Get random sample
            path, target = random.choice(self.samples)
            # Get video object
            vid = torchvision.io.VideoReader(path, "video")
            metadata = vid.get_metadata()
            video_frames = []  # video frame buffer

            # Seek and return frames
            max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
            start = random.uniform(0., max_seek)
            for frame in itertools.islice(vid.seek(start), self.clip_len):
                video_frames.append(self.frame_transform(frame['data']))
                current_pts = frame['pts']
            # Stack it into a tensor
            video = torch.stack(video_frames, 0)
            if self.video_transform:
                video = self.video_transform(video)
            output = {
                'path': path,
                'video': video,
                'target': target,
                'start': start,
                'end': current_pts}
            yield output
# %%
# Given a path of videos in a folder structure, i.e.:
#
# - dataset
#     - class 1
#         - file 0
#         - file 1
#         - ...
#     - class 2
#         - file 0
#         - file 1
#         - ...
#     - ...
#
# We can generate a dataloader and test the dataset.

transforms = [t.Resize((112, 112))]
frame_transform = t.Compose(transforms)

dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform)
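
# %%
# A ``video_transform`` is applied to the whole stacked clip (a ``(T, C, H, W)`` tensor)
# rather than to individual frames. A minimal sketch, assuming we simply want the clip as
# normalized floats (the transform choice here is illustrative, not part of the original
# example); we keep using ``dataset`` (frame transform only) for the loader below:
video_transform = t.Compose([t.ConvertImageDtype(torch.float32)])
dataset_with_video_transform = RandomDataset(
    "./dataset", frame_transform=frame_transform, video_transform=video_transform
)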
# %%
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=12)
data = {"video": [], 'start': [], 'end': [], 'tensorsize': []}
for batch in loader:
    for i in range(len(batch['path'])):
        data['video'].append(batch['path'][i])
        data['start'].append(batch['start'][i].item())
        data['end'].append(batch['end'][i].item())
        data['tensorsize'].append(batch['video'][i].size())
print(data)
# %%
# 4. Data Visualization
# ----------------------------------
# Example of visualized video

import matplotlib.pyplot as plt

plt.figure(figsize=(12, 12))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0))
    plt.axis("off")
# %%
# Clean up the video and dataset:
import os
import shutil
os.remove("./WUzgd7C1pWA.mp4")
shutil.rmtree("./dataset")