  1. """
  2. =====================================================
  3. Optical Flow: Predicting movement with the RAFT model
  4. =====================================================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_optical_flow.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_optical_flow.py>` to download the full example code.
  8. Optical flow is the task of predicting movement between two images, usually two
  9. consecutive frames of a video. Optical flow models take two images as input, and
  10. predict a flow: the flow indicates the displacement of every single pixel in the
  11. first image, and maps it to its corresponding pixel in the second image. Flows
  12. are (2, H, W)-dimensional tensors, where the first axis corresponds to the
  13. predicted horizontal and vertical displacements.
  14. The following example illustrates how torchvision can be used to predict flows
  15. using our implementation of the RAFT model. We will also see how to convert the
  16. predicted flows to RGB images for visualization.
  17. """
import numpy as np
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F


plt.rcParams["savefig.bbox"] = "tight"
# sphinx_gallery_thumbnail_number = 2
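
# %%
# As a toy illustration of the flow convention described above (our own small
# example, not part of the original script): for a flow ``f`` of shape
# (2, H, W), ``f[0, y, x]`` is the horizontal displacement and ``f[1, y, x]``
# the vertical displacement of the pixel at ``(y, x)`` in the first image.

toy_flow = torch.zeros(2, 4, 4)
toy_flow[0, 1, 2] = 3.0  # the pixel at (y=1, x=2) moves 3 pixels to the right...
toy_flow[1, 1, 2] = -1.0  # ...and 1 pixel up
# so its predicted position in the second image is (x, y) = (5, 0)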


def plot(imgs, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            img = F.to_pil_image(img.to("cpu"))
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    plt.tight_layout()

# %%
# Reading Videos Using Torchvision
# --------------------------------
# We will first read a video using :func:`~torchvision.io.read_video`.
# Alternatively one can use the new :class:`~torchvision.io.VideoReader` API (if
# torchvision is built from source); a short sketch of it follows the download
# code below.
#
# The video we will use here is free to use from `pexels.com
# <https://www.pexels.com/video/a-man-playing-a-game-of-basketball-5192157/>`_,
# credits go to `Pavel Danilyuk <https://www.pexels.com/@pavel-danilyuk>`_.

import tempfile
from pathlib import Path
from urllib.request import urlretrieve


video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
_ = urlretrieve(video_url, video_path)
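
# %%
# For reference, here is a minimal sketch of the alternative
# :class:`~torchvision.io.VideoReader` API mentioned above. We keep it commented
# out since it requires torchvision to be built with video support, and we don't
# use it in the rest of this example:
#
# from torchvision.io import VideoReader
# reader = VideoReader(str(video_path), "video")
# frames_alt = [frame["data"] for frame in reader]  # each frame is a dict with "data" and "pts"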

# %%
# :func:`~torchvision.io.read_video` returns the video frames, audio frames and
# the metadata associated with the video. In our case, we only need the video
# frames.
#
# Here we will just make 2 predictions between 2 pre-selected pairs of frames,
# namely frames (100, 101) and (150, 151). Each of these pairs corresponds to a
# single model input.

from torchvision.io import read_video
frames, _, _ = read_video(str(video_path), output_format="TCHW")

img1_batch = torch.stack([frames[100], frames[150]])
img2_batch = torch.stack([frames[101], frames[151]])

plot(img1_batch)

# %%
# The RAFT model accepts RGB images. We first get the frames from
# :func:`~torchvision.io.read_video` and resize them to ensure their dimensions
# are divisible by 8. Note that we explicitly use ``antialias=False``, because
# this is how those models were trained. Then we use the transforms bundled into
# the weights in order to preprocess the input and rescale its values to the
# required ``[-1, 1]`` interval.

from torchvision.models.optical_flow import Raft_Large_Weights

weights = Raft_Large_Weights.DEFAULT
transforms = weights.transforms()


def preprocess(img1_batch, img2_batch):
    img1_batch = F.resize(img1_batch, size=[520, 960], antialias=False)
    img2_batch = F.resize(img2_batch, size=[520, 960], antialias=False)
    return transforms(img1_batch, img2_batch)


img1_batch, img2_batch = preprocess(img1_batch, img2_batch)

print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}")

# %%
# Estimating Optical flow using RAFT
# ----------------------------------
# We will use our RAFT implementation from
# :func:`~torchvision.models.optical_flow.raft_large`, which follows the same
# architecture as the one described in the `original paper <https://arxiv.org/abs/2003.12039>`_.
# We also provide the :func:`~torchvision.models.optical_flow.raft_small` model
# builder, which is smaller and faster to run, sacrificing a bit of accuracy.

from torchvision.models.optical_flow import raft_large

# If you can, run this example on a GPU, it will be a lot faster.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
model = model.eval()

list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
print(f"type = {type(list_of_flows)}")
print(f"length = {len(list_of_flows)} = number of iterations of the model")

# %%
# The RAFT model outputs lists of predicted flows where each entry is a
# (N, 2, H, W) batch of predicted flows that corresponds to a given "iteration"
# in the model. For more details on the iterative nature of the model, please
# refer to the `original paper <https://arxiv.org/abs/2003.12039>`_. Here, we
# are only interested in the final predicted flows (they are the most accurate
# ones), so we will just retrieve the last item in the list.
#
# As described above, a flow is a tensor with dimensions (2, H, W) (or (N, 2, H,
# W) for batches of flows) where each entry corresponds to the horizontal and
# vertical displacement of each pixel from the first image to the second image.
# Note that the predicted flows are in "pixel" units; they are not normalized
# w.r.t. the dimensions of the images.

predicted_flows = list_of_flows[-1]
print(f"dtype = {predicted_flows.dtype}")
print(f"shape = {predicted_flows.shape} = (N, 2, H, W)")
print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}")

# %%
# Visualizing predicted flows
# ---------------------------
# Torchvision provides the :func:`~torchvision.utils.flow_to_image` utility to
# convert a flow into an RGB image. It also supports batches of flows. Each
# "direction" in the flow will be mapped to a given RGB color. In the
# images below, pixels with similar colors are assumed by the model to be moving
# in similar directions. The model is properly able to predict the movement of
# the ball and the player. Note in particular the different predicted direction
# of the ball in the first image (going to the left) and in the second image
# (going up).

from torchvision.utils import flow_to_image

flow_imgs = flow_to_image(predicted_flows)

# The images have been mapped into [-1, 1] but for plotting we want them in [0, 1]
img1_batch = [(img1 + 1) / 2 for img1 in img1_batch]

grid = [[img1, flow_img] for (img1, flow_img) in zip(img1_batch, flow_imgs)]
plot(grid)
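
# %%
# To build intuition for the color coding, one can visualize a synthetic flow
# (our own illustration, not part of the original example). Here every pixel
# "points away" from the center, so the image below shows the full color wheel:

H, W = 151, 151
ys, xs = torch.meshgrid(torch.arange(H) - H // 2, torch.arange(W) - W // 2, indexing="ij")
synthetic_flow = torch.stack((xs, ys)).float()[None]  # (1, 2, H, W)
plot(flow_to_image(synthetic_flow))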

# %%
# Bonus: Creating GIFs of predicted flows
# ---------------------------------------
# In the example above we have only shown the predicted flows of 2 pairs of
# frames. A fun way to apply optical flow models is to run the model on an
# entire video, and create a new video from all the predicted flows. Below is a
# snippet that can get you started with this. We comment out the code, because
# this example is being rendered on a machine without a GPU, and it would take
# too long to run it.

# from torchvision.io import write_jpeg
# for i, (img1, img2) in enumerate(zip(frames, frames[1:])):
#     # Note: it would be faster to predict batches of flows instead of individual flows
#     img1, img2 = preprocess(img1[None], img2[None])  # add the batch dimension the model expects
#     list_of_flows = model(img1.to(device), img2.to(device))
#     predicted_flow = list_of_flows[-1][0]
#     flow_img = flow_to_image(predicted_flow).to("cpu")
#     output_folder = "/tmp/"  # Update this to the folder of your choice
#     write_jpeg(flow_img, output_folder + f"predicted_flow_{i}.jpg")
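
# %%
# As the note above hints, batching frame pairs amortizes the cost of each model
# call. A minimal sketch of the batched variant (also commented out, with the
# same caveats as above; it reuses ``write_jpeg`` and ``output_folder`` from the
# previous snippet, and the batch size of 8 is an arbitrary choice):
#
# batch_size = 8
# for i in range(0, len(frames) - 1, batch_size):
#     img1_batch = frames[i : i + batch_size]
#     img2_batch = frames[i + 1 : i + batch_size + 1]
#     n = min(len(img1_batch), len(img2_batch))  # the last pair of slices can differ in length
#     img1_batch, img2_batch = preprocess(img1_batch[:n], img2_batch[:n])
#     predicted_flows = model(img1_batch.to(device), img2_batch.to(device))[-1]
#     for j, predicted_flow in enumerate(predicted_flows):
#         flow_img = flow_to_image(predicted_flow).to("cpu")
#         write_jpeg(flow_img, output_folder + f"predicted_flow_{i + j}.jpg")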

# %%
# Once the .jpg flow images are saved, you can convert them into a video or a
# GIF using ffmpeg with e.g.:
#
# ffmpeg -f image2 -framerate 30 -i predicted_flow_%d.jpg -loop -1 flow.gif