.. _models:

Models and pre-trained weights
##############################

The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

General information on pre-trained weights
==========================================

TorchVision offers pre-trained weights for every provided architecture, using
the PyTorch :mod:`torch.hub`. Instantiating a pre-trained model will download its
weights to a cache directory. This directory can be set using the `TORCH_HOME`
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.

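For example, a minimal sketch of redirecting the cache directory before
instantiating a model (the path below is just a placeholder):

.. code:: python

    import os

    # Hypothetical cache location; downloaded weights end up under
    # $TORCH_HOME/hub/checkpoints/
    os.environ["TORCH_HOME"] = "/path/to/my/cache"

    from torchvision.models import resnet50, ResNet50_Weights
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
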
.. note::

    The pre-trained models provided in this library may have their own licenses or
    terms and conditions derived from the dataset used for training. It is your
    responsibility to determine whether you have permission to use the models for
    your use case.

.. note::

    Backward compatibility is guaranteed for loading a serialized
    ``state_dict`` into a model created with an older PyTorch version.
    In contrast, loading entire saved models or serialized
    ``ScriptModules`` (serialized using older versions of PyTorch)
    may not preserve the historic behaviour. Refer to the following
    `documentation
    <https://pytorch.org/docs/stable/notes/serialization.html#id6>`_.

Initializing pre-trained models
-------------------------------

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
for loading different weights to the existing model builder methods:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Old weights with accuracy 76.130%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # New weights with accuracy 80.858%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Best available weights (currently alias for IMAGENET1K_V2)
    # Note that these weights may change across versions
    resnet50(weights=ResNet50_Weights.DEFAULT)

    # Strings are also supported
    resnet50(weights="IMAGENET1K_V2")

    # No weights - random initialization
    resnet50(weights=None)

Migrating to the new API is very straightforward. The following method calls between the two APIs are all equivalent:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Using pretrained weights:
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    resnet50(weights="IMAGENET1K_V1")
    resnet50(pretrained=True)  # deprecated
    resnet50(True)  # deprecated

    # Using no weights:
    resnet50(weights=None)
    resnet50()
    resnet50(pretrained=False)  # deprecated
    resnet50(False)  # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings and it will be removed in v0.15.

Using the pre-trained models
----------------------------

Before using the pre-trained models, one must preprocess the image
(resize with the right resolution/interpolation, apply inference transforms,
rescale the values, etc.). There is no standard way to do this as it depends on
how a given model was trained. It can vary across model families, variants or
even weight versions. Using the correct preprocessing method is critical and
failing to do so may lead to decreased accuracy or incorrect outputs.

All the necessary information for the inference transforms of each pre-trained
model is provided on its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:

.. code:: python

    # Initialize the Weight Transforms
    weights = ResNet50_Weights.DEFAULT
    preprocess = weights.transforms()

    # Apply it to the input image
    img_transformed = preprocess(img)

Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.

.. code:: python

    # Initialize model
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)

    # Set model to eval mode
    model.eval()

Listing and retrieving available models
---------------------------------------

As of v0.14, TorchVision offers a new mechanism which allows listing and
retrieving models and weights by their names. Here are a few examples on how to
use them:

.. code:: python

    import torchvision
    from torchvision.models import get_model, get_model_weights, get_weight, list_models
    from torchvision.models.quantization import MobileNet_V3_Large_QuantizedWeights

    # List available models
    all_models = list_models()
    classification_models = list_models(module=torchvision.models)

    # Initialize models
    m1 = get_model("mobilenet_v3_large", weights=None)
    m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")

    # Fetch weights
    weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
    assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT

    weights_enum = get_model_weights("quantized_mobilenet_v3_large")
    assert weights_enum == MobileNet_V3_Large_QuantizedWeights

    weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
    assert weights_enum == weights_enum2

Here are the available public functions to retrieve models and their corresponding weights:

.. currentmodule:: torchvision.models

.. autosummary::
    :toctree: generated/
    :template: function.rst

    get_model
    get_model_weights
    get_weight
    list_models

Using models from Hub
---------------------

Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:

.. code:: python

    import torch

    # Option 1: passing weights param as string
    model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

    # Option 2: passing weights param as enum
    weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
    model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:

.. code:: python

    import torch

    weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
    print([weight for weight in weight_enum])

The only exceptions to the above are the detection models included in
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.

Classification
==============

.. currentmodule:: torchvision.models

The following classification models are available, with or without pre-trained
weights:

.. toctree::
    :maxdepth: 1

    models/alexnet
    models/convnext
    models/densenet
    models/efficientnet
    models/efficientnetv2
    models/googlenet
    models/inception
    models/maxvit
    models/mnasnet
    models/mobilenetv2
    models/mobilenetv3
    models/regnet
    models/resnet
    models/resnext
    models/shufflenetv2
    models/squeezenet
    models/swin_transformer
    models/vgg
    models/vision_transformer
    models/wide_resnet

|

Here is an example of how to use the pre-trained image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

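For instance, a small sketch (reusing ``prediction`` and ``weights`` from the
example above) that prints the five most likely categories:

.. code:: python

    # Print the top-5 categories and their scores
    top5_scores, top5_ids = prediction.topk(5)
    for score, class_id in zip(top5_scores.tolist(), top5_ids.tolist()):
        print(f"{weights.meta['categories'][class_id]}: {100 * score:.1f}%")
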
Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_table.rst

Quantized models
----------------

.. currentmodule:: torchvision.models.quantization

The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:

.. toctree::
    :maxdepth: 1

    models/googlenet_quant
    models/inception_quant
    models/mobilenetv2_quant
    models/mobilenetv3_quant
    models/resnet_quant
    models/resnext_quant
    models/shufflenetv2_quant

|

Here is an example of how to use the pre-trained quantized image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_QuantizedWeights.DEFAULT
    model = resnet50(weights=weights, quantize=True)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_quant_table.rst

Semantic Segmentation
=====================

.. currentmodule:: torchvision.models.segmentation

.. betastatus:: segmentation module

The following semantic segmentation models are available, with or without
pre-trained weights:

.. toctree::
    :maxdepth: 1

    models/deeplabv3
    models/fcn
    models/lraspp

|

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.

Table of all available semantic segmentation weights
----------------------------------------------------

All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:

.. include:: generated/segmentation_table.rst

.. _object_det_inst_seg_pers_keypoint_det:

Object Detection, Instance Segmentation and Person Keypoint Detection
======================================================================

The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision. The models expect a list of ``Tensor[C, H, W]``.
Check the constructor of the models for more information.

.. betastatus:: detection module

Object Detection
----------------

.. currentmodule:: torchvision.models.detection

The following object detection models are available, with or without pre-trained
weights:

.. toctree::
    :maxdepth: 1

    models/faster_rcnn
    models/fcos
    models/retinanet
    models/ssd
    models/ssdlite

|

Here is an example of how to use the pre-trained object detection models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box MAPs are reported on COCO val2017:

.. include:: generated/detection_table.rst

Instance Segmentation
---------------------

.. currentmodule:: torchvision.models.detection

The following instance segmentation models are available, with or without pre-trained
weights:

.. toctree::
    :maxdepth: 1

    models/mask_rcnn

|

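The pre-trained instance segmentation models can be used in the same way as the
object detection models above. Below is a minimal, illustrative sketch following
the same four steps (the image path is just a placeholder):

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_segmentation_masks
    from torchvision.transforms.functional import to_pil_image

    img = read_image("path/to/your/image.jpg")  # placeholder path

    # Step 1: Initialize model with the best available weights
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the predicted masks
    prediction = model(batch)[0]
    # prediction["masks"] has shape [N, 1, H, W] with per-pixel probabilities;
    # threshold them to obtain boolean masks
    masks = prediction["masks"].squeeze(1) > 0.5
    overlay = draw_segmentation_masks(img, masks=masks, alpha=0.7)
    to_pil_image(overlay).show()
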
For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Mask MAPs are reported on COCO val2017:

.. include:: generated/instance_segmentation_table.rst

Keypoint Detection
------------------

.. currentmodule:: torchvision.models.detection

The following person keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
    :maxdepth: 1

    models/keypoint_rcnn

|

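As an illustration, here is a minimal sketch of running the pre-trained Keypoint
R-CNN model, mirroring the object detection example above (the image path is just
a placeholder):

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
    from torchvision.utils import draw_keypoints
    from torchvision.transforms.functional import to_pil_image

    img = read_image("path/to/your/image.jpg")  # placeholder path

    # Step 1: Initialize model with the best available weights
    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the predicted keypoints
    prediction = model(batch)[0]
    # prediction["keypoints"] has shape [N, K, 3] (x, y, visibility);
    # draw_keypoints only needs the (x, y) coordinates
    keypoints = prediction["keypoints"][:, :, :2]
    overlay = draw_keypoints(img, keypoints, colors="blue", radius=3)
    to_pil_image(overlay).show()
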
The keypoint names of the pre-trained models can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the keypoints of the models, you may refer to :ref:`keypoint_output`.

Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Keypoint MAPs are reported on COCO val2017:

.. include:: generated/detection_keypoint_table.rst

Video Classification
====================

.. currentmodule:: torchvision.models.video

.. betastatus:: video module

The following video classification models are available, with or without
pre-trained weights:

.. toctree::
    :maxdepth: 1

    models/video_mvit
    models/video_resnet
    models/video_s3d
    models/video_swin_transformer

|

Here is an example of how to use the pre-trained video classification models:

.. code:: python

    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
    vid = vid[:32]  # optionally shorten duration

    # Step 1: Initialize model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    category_name = weights.meta["categories"][label]
    print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available video classification weights
---------------------------------------------------

Accuracies are reported on Kinetics-400 using single crops for clip length 16:

.. include:: generated/video_table.rst

Optical Flow
============

.. currentmodule:: torchvision.models.optical_flow

The following Optical Flow models are available, with or without pre-trained
weights:

.. toctree::
    :maxdepth: 1

    models/raft

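As an illustration, here is a minimal sketch of running the pre-trained RAFT
model on a pair of consecutive frames, following the same steps as the other
examples (the video path and the resize target are just placeholders; RAFT only
requires the spatial dimensions to be divisible by 8):

.. code:: python

    from torchvision.io.video import read_video
    from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
    from torchvision.transforms.functional import resize, to_pil_image
    from torchvision.utils import flow_to_image

    frames, _, _ = read_video("path/to/your/video.mp4", output_format="TCHW")  # placeholder path
    img1 = frames[0].unsqueeze(0)  # first frame as a batch of size 1
    img2 = frames[1].unsqueeze(0)  # second frame as a batch of size 1

    # RAFT expects spatial dimensions divisible by 8; the target size is arbitrary here
    img1 = resize(img1, size=[520, 960])
    img2 = resize(img2, size=[520, 960])

    # Step 1: Initialize model with the best available weights
    weights = Raft_Large_Weights.DEFAULT
    model = raft_large(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms (they take both frames)
    img1, img2 = preprocess(img1, img2)

    # Step 4: Use the model; RAFT returns one flow estimate per iteration,
    # the last one is the most accurate
    flows = model(img1, img2)
    predicted_flow = flows[-1][0]
    to_pil_image(flow_to_image(predicted_flow)).show()
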