Feature extraction for model inspection
========================================

.. currentmodule:: torchvision.models.feature_extraction

The ``torchvision.models.feature_extraction`` package contains
feature extraction utilities that let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:

- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
  recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
  with a specific task in mind. For example, passing a hierarchy of features
  to a Feature Pyramid Network with object detection heads.

Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.

|

The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
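
In its simplest form, usage looks roughly like the following minimal sketch
(the choice of ``resnet18`` and the ``"layer2"`` node are purely illustrative;
node naming is explained below):

.. code-block:: python

    import torch
    from torchvision.models import resnet18
    from torchvision.models.feature_extraction import create_feature_extractor

    # Build an extractor that returns the output of the node "layer2"
    model = resnet18()
    extractor = create_feature_extractor(model, return_nodes=["layer2"])
    out = extractor(torch.rand(1, 3, 224, 224))
    print(out["layer2"].shape)  # expected: torch.Size([1, 128, 28, 28])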

.. _about-node-names:

**About Node Names**

In order to specify which nodes should be output nodes for extracted
features, one should be familiar with the node naming convention used here
(which differs slightly from that used in ``torch.fx``). A node name is
specified as a ``.`` separated path walking the module hierarchy from top level
module down to leaf operation or leaf module. For instance ``"layer4.2.relu"``
in ResNet-50 represents the output of the ReLU of the 3rd block (index 2) of the
4th layer of the ``ResNet`` module. Here are some finer points to keep in mind:

- When specifying node names for :func:`create_feature_extractor`, you may
  provide a truncated version of a node name as a shortcut. To see how this
  works, try creating a ResNet-50 model and printing the node names with
  ``train_nodes, _ = get_graph_node_names(model)`` followed by
  ``print(train_nodes)`` (see the sketch after this list), and observe that the
  last node pertaining to ``layer4`` is ``"layer4.2.relu_2"``. One may specify
  ``"layer4.2.relu_2"`` as the return node, or just ``"layer4"``, as this, by
  convention, refers to the last node (in order of execution) of ``layer4``.
- If a certain module or operation is repeated more than once, node names get
  an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
  addition (``+``) operation is used three times in the same ``forward``
  method. Then there would be ``"path.to.module.add"``,
  ``"path.to.module.add_1"``, ``"path.to.module.add_2"``. The counter is
  maintained within the scope of the direct parent. So in ResNet-50 there is
  a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
  operations reside in different blocks, there is no need for a postfix to
  disambiguate.
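
For example, the following sketch prints the node names produced for the last
block of ``layer4`` in a ResNet-50, showing the ``_{int}`` postfix convention
in action (the exact list of names may vary across torchvision versions):

.. code-block:: python

    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names

    model = resnet50()
    train_nodes, _ = get_graph_node_names(model)
    # The same `relu` module is called several times inside the block, so the
    # resulting nodes are disambiguated with the `_{int}` postfix.
    print([name for name in train_nodes if name.startswith("layer4.2.relu")])
    # e.g. ['layer4.2.relu', 'layer4.2.relu_1', 'layer4.2.relu_2']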

**An Example**

Here is an example of how we might extract features for MaskRCNN:

.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names
    from torchvision.models.feature_extraction import create_feature_extractor
    from torchvision.models.detection.mask_rcnn import MaskRCNN
    from torchvision.models.detection.backbone_utils import LastLevelMaxPool
    from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork

    # To assist you in designing the feature extractor you may want to print out
    # the available nodes for resnet50.
    m = resnet50()
    train_nodes, eval_nodes = get_graph_node_names(m)

    # The lists returned are the names of all the graph nodes (in order of
    # execution) for the input model traced in train mode and in eval mode
    # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
    # for this example. But if the model contains control flow that's dependent
    # on the training mode, they may be different.

    # To specify the nodes you want to extract, you could select the final node
    # that appears in each of the main layers:
    return_nodes = {
        # node_name: user-specified key for output dict
        'layer1.2.relu_2': 'layer1',
        'layer2.3.relu_2': 'layer2',
        'layer3.5.relu_2': 'layer3',
        'layer4.2.relu_2': 'layer4',
    }

    # But `create_feature_extractor` can also accept truncated node specifications
    # like "layer1", as it will just pick the last node that's a descendant of
    # the specification. (Tip: be careful with this, especially when a layer
    # has multiple outputs. It's not always guaranteed that the last operation
    # performed is the one that corresponds to the output you desire. You should
    # consult the source code for the input model to confirm.)
    return_nodes = {
        'layer1': 'layer1',
        'layer2': 'layer2',
        'layer3': 'layer3',
        'layer4': 'layer4',
    }

    # Now you can build the feature extractor. This returns a module whose forward
    # method returns a dictionary like:
    # {
    #     'layer1': output of layer 1,
    #     'layer2': output of layer 2,
    #     'layer3': output of layer 3,
    #     'layer4': output of layer 4,
    # }
    extractor = create_feature_extractor(m, return_nodes=return_nodes)

    # Let's put all that together to wrap resnet50 with MaskRCNN

    # MaskRCNN requires a backbone with an attached FPN
    class Resnet50WithFPN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Get a resnet50 backbone
            m = resnet50()
            # Extract 4 main layers (note: MaskRCNN needs this particular name
            # mapping for return nodes)
            self.body = create_feature_extractor(
                m, return_nodes={f'layer{k}': str(v)
                                 for v, k in enumerate([1, 2, 3, 4])})
            # Dry run to get number of channels for FPN
            inp = torch.randn(2, 3, 224, 224)
            with torch.no_grad():
                out = self.body(inp)
            in_channels_list = [o.shape[1] for o in out.values()]
            # Build FPN
            self.out_channels = 256
            self.fpn = FeaturePyramidNetwork(
                in_channels_list, out_channels=self.out_channels,
                extra_blocks=LastLevelMaxPool())

        def forward(self, x):
            x = self.body(x)
            x = self.fpn(x)
            return x


    # Now we can build our model!
    model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
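
With the backbone assembled, the detector can be exercised like any other
torchvision detection model. A brief usage sketch (the input sizes here are
arbitrary):

.. code-block:: python

    images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
    with torch.no_grad():
        predictions = model(images)
    # In eval mode, each element of `predictions` is a dict with
    # 'boxes', 'labels', 'scores' and 'masks' entries.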

API Reference
-------------

.. autosummary::
    :toctree: generated/
    :template: function.rst

    create_feature_extractor
    get_graph_node_names