
Feature extraction for model inspection
=======================================

.. currentmodule:: torchvision.models.feature_extraction

The ``torchvision.models.feature_extraction`` package contains
feature extraction utilities that let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:

- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
  recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
  with a specific task in mind. For example, passing a hierarchy of features
  to a Feature Pyramid Network with object detection heads.

Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating Python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.

|

The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
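The tracing step can be made concrete with a small sketch. The toy module
below is purely illustrative (it is not part of torchvision);
``torch.fx.symbolic_trace`` replays its ``forward`` with a proxy input and
records each operation as a named graph node:

.. code-block:: python

    import torch
    import torch.fx


    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            x = self.linear(x)
            x = torch.relu(x)
            return x.sum()


    # Symbolic tracing yields nodes for the input placeholder, each
    # operation ("linear", "relu", "sum"), and the output.
    traced = torch.fx.symbolic_trace(Toy())
    print([node.name for node in traced.graph.nodes])

Steps 2 to 4 then amount to marking some of these nodes as outputs and
discarding everything downstream of them, which
:func:`create_feature_extractor` automates.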
.. _about-node-names:

**About Node Names**

In order to specify which nodes should be output nodes for extracted
features, one should be familiar with the node naming convention used here
(which differs slightly from that used in ``torch.fx``). A node name is
specified as a ``.``-separated path walking the module hierarchy from the
top-level module down to a leaf operation or leaf module. For instance
``"layer4.2.relu"`` in ResNet-50 represents the output of the ReLU of the
3rd block (index 2) of the 4th layer of the ``ResNet`` module. Here are some
finer points to keep in mind:

- When specifying node names for :func:`create_feature_extractor`, you may
  provide a truncated version of a node name as a shortcut. To see how this
  works, try creating a ResNet-50 model and printing the node names with
  ``train_nodes, _ = get_graph_node_names(model)`` followed by
  ``print(train_nodes)``, and observe that the last node pertaining to
  ``layer4`` is ``"layer4.2.relu_2"``. One may specify ``"layer4.2.relu_2"``
  as the return node, or just ``"layer4"``, as this, by convention, refers
  to the last node (in order of execution) of ``layer4``.
- If a certain module or operation is repeated more than once, node names get
  an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
  addition (``+``) operation is used three times in the same ``forward``
  method. Then there would be ``"path.to.module.add"``,
  ``"path.to.module.add_1"``, and ``"path.to.module.add_2"``. The counter is
  maintained within the scope of the direct parent. So in ResNet-50 there is
  a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
  operations reside in different blocks, there is no need for a postfix to
  disambiguate.
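These conventions can be checked directly with :func:`get_graph_node_names`.
The snippet below lists the ``layer4`` nodes of ResNet-50; the ``"_1"`` and
``"_2"`` postfixes, and the per-block ``add`` counters described above, show
up in the output:

.. code-block:: python

    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names

    model = resnet50()
    train_nodes, _ = get_graph_node_names(model)
    # The repeated ReLUs in block "layer4.2" appear as "layer4.2.relu",
    # "layer4.2.relu_1" and "layer4.2.relu_2", while the "add" counter
    # restarts in each block ("layer4.1.add", "layer4.2.add").
    print([name for name in train_nodes if name.startswith('layer4.')])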
**An Example**

Here is an example of how we might extract features for MaskRCNN:

.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names
    from torchvision.models.feature_extraction import create_feature_extractor
    from torchvision.models.detection.mask_rcnn import MaskRCNN
    from torchvision.models.detection.backbone_utils import LastLevelMaxPool
    from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


    # To assist you in designing the feature extractor you may want to print out
    # the available nodes for resnet50.
    m = resnet50()
    train_nodes, eval_nodes = get_graph_node_names(m)

    # The lists returned are the names of all the graph nodes (in order of
    # execution) for the input model traced in train mode and in eval mode
    # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
    # for this example. But if the model contains control flow that's dependent
    # on the training mode, they may be different.

    # To specify the nodes you want to extract, you could select the final node
    # that appears in each of the main layers:
    return_nodes = {
        # node_name: user-specified key for output dict
        'layer1.2.relu_2': 'layer1',
        'layer2.3.relu_2': 'layer2',
        'layer3.5.relu_2': 'layer3',
        'layer4.2.relu_2': 'layer4',
    }

    # But `create_feature_extractor` can also accept truncated node
    # specifications like "layer1", as it will just pick the last node that's a
    # descendant of the specification. (Tip: be careful with this, especially
    # when a layer has multiple outputs. It's not always guaranteed that the
    # last operation performed is the one that corresponds to the output you
    # desire. You should consult the source code for the input model to
    # confirm.)
    return_nodes = {
        'layer1': 'layer1',
        'layer2': 'layer2',
        'layer3': 'layer3',
        'layer4': 'layer4',
    }

    # Now you can build the feature extractor. This returns a module whose
    # forward method returns a dictionary like:
    # {
    #     'layer1': output of layer 1,
    #     'layer2': output of layer 2,
    #     'layer3': output of layer 3,
    #     'layer4': output of layer 4,
    # }
    create_feature_extractor(m, return_nodes=return_nodes)

    # Let's put all that together to wrap resnet50 with MaskRCNN

    # MaskRCNN requires a backbone with an attached FPN
    class Resnet50WithFPN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Get a resnet50 backbone
            m = resnet50()
            # Extract 4 main layers (note: MaskRCNN needs this particular name
            # mapping for return nodes)
            self.body = create_feature_extractor(
                m, return_nodes={f'layer{k}': str(v)
                                 for v, k in enumerate([1, 2, 3, 4])})
            # Dry run to get number of channels for FPN
            inp = torch.randn(2, 3, 224, 224)
            with torch.no_grad():
                out = self.body(inp)
            in_channels_list = [o.shape[1] for o in out.values()]
            # Build FPN
            self.out_channels = 256
            self.fpn = FeaturePyramidNetwork(
                in_channels_list, out_channels=self.out_channels,
                extra_blocks=LastLevelMaxPool())

        def forward(self, x):
            x = self.body(x)
            x = self.fpn(x)
            return x


    # Now we can build our model!
    model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
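Before wiring a backbone into a detector, it can help to sanity-check what a
feature extractor returns. The standalone snippet below uses the same
truncated node names as above and shows that, for a 224x224 input, the four
ResNet-50 feature maps have the channel counts the FPN dry run relies on:

.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import create_feature_extractor

    extractor = create_feature_extractor(
        resnet50(),
        return_nodes={f'layer{i}': f'layer{i}' for i in range(1, 5)})
    with torch.no_grad():
        out = extractor(torch.randn(1, 3, 224, 224))
    for name, feat in out.items():
        print(name, tuple(feat.shape))
    # layer1 (1, 256, 56, 56)
    # layer2 (1, 512, 28, 28)
    # layer3 (1, 1024, 14, 14)
    # layer4 (1, 2048, 7, 7)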
API Reference
-------------

.. autosummary::
    :toctree: generated/
    :template: function.rst

    create_feature_extractor
    get_graph_node_names