  1. """
  2. ===================
  3. Torchscript support
  4. ===================
  5. .. note::
  6. Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_scripted_tensor_transforms.ipynb>`_
  7. or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_scripted_tensor_transforms.py>` to download the full example code.
  8. This example illustrates `torchscript
  9. <https://pytorch.org/docs/stable/jit.html>`_ support of the torchvision
  10. :ref:`transforms <transforms>` on Tensor images.
  11. """
# %%
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

import torchvision.transforms as v1
from torchvision.io import read_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)

# If you're trying to run this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
import sys
sys.path += ["../transforms"]
from helpers import plot

ASSETS_PATH = Path('../assets')
# %%
# Most transforms support torchscript. For composing transforms, we use
# :class:`torch.nn.Sequential` instead of
# :class:`~torchvision.transforms.v2.Compose`:

dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))

transforms = torch.nn.Sequential(
    v1.RandomCrop(224),
    v1.RandomHorizontalFlip(p=0.3),
)

scripted_transforms = torch.jit.script(transforms)

plot([dog1, scripted_transforms(dog1), dog2, scripted_transforms(dog2)])
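
# %%
# As a quick illustrative sketch (this snippet is an aside, not a required
# step): the tensor transforms also accept batches of images of shape
# ``(B, C, H, W)``, and the scripted module handles them the same way.

# A random uint8 batch of four 256x256 images; after RandomCrop(224) the
# output should have shape (4, 3, 224, 224).
batch_demo = torch.randint(0, 256, size=(4, 3, 256, 256), dtype=torch.uint8)
print(scripted_transforms(batch_demo).shape)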
# %%
# .. warning::
#
#     Above we have used transforms from the ``torchvision.transforms``
#     namespace, i.e. the "v1" transforms. The v2 transforms from the
#     ``torchvision.transforms.v2`` namespace are the :ref:`recommended
#     <v1_or_v2>` way to use transforms in your code.
#
#     The v2 transforms also support torchscript, but if you call
#     ``torch.jit.script()`` on a v2 **class** transform, you'll actually end up
#     with its (scripted) v1 equivalent. This may lead to slightly different
#     results between the scripted and eager executions due to implementation
#     differences between v1 and v2.
#
#     If you really need torchscript support for the v2 transforms, **we
#     recommend scripting the functionals** from the
#     ``torchvision.transforms.v2.functional`` namespace to avoid surprises,
#     as sketched below.
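
# %%
# As a minimal sketch of that recommendation, we can script a functional such
# as ``horizontal_flip`` directly, instead of a v2 class transform:

from torchvision.transforms.v2 import functional as F

# Scripting the functional avoids the silent fallback to the v1 transform.
scripted_hflip = torch.jit.script(F.horizontal_flip)
plot([dog1, scripted_hflip(dog1)])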
# %%
# We now show how to combine image transformations and a model forward
# pass, while using ``torch.jit.script`` to obtain a single scripted module.
#
# Let's define a ``Predictor`` module that transforms the input tensor and then
# applies an ImageNet model on it.
from torchvision.models import resnet18, ResNet18_Weights


class Predictor(nn.Module):

    def __init__(self):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT
        self.resnet18 = resnet18(weights=weights, progress=False).eval()
        self.transforms = weights.transforms(antialias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = self.transforms(x)
            y_pred = self.resnet18(x)
            return y_pred.argmax(dim=1)
# %%
# Now, let's define scripted and non-scripted instances of ``Predictor`` and
# apply them to multiple tensor images of the same size:

device = "cuda" if torch.cuda.is_available() else "cpu"

predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)

batch = torch.stack([dog1, dog2]).to(device)

res = predictor(batch)
res_scripted = scripted_predictor(batch)
# %%
# We can verify that the predictions of the scripted and non-scripted models
# are the same:

import json

with open(ASSETS_PATH / 'imagenet_class_index.json') as labels_file:
    labels = json.load(labels_file)

for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
    assert pred == pred_scripted
    print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")
# %%
# Since the model is scripted, it can easily be saved to disk and re-used:

import tempfile

with tempfile.NamedTemporaryFile() as f:
    scripted_predictor.save(f.name)

    dumped_scripted_predictor = torch.jit.load(f.name)
    res_scripted_dumped = dumped_scripted_predictor(batch)
assert (res_scripted_dumped == res_scripted).all()

# %%