123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- import warnings
- from typing import Callable, Dict, List, Optional, Union
- from torch import nn, Tensor
- from torchvision.ops import misc as misc_nn_ops
- from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
- from .. import mobilenet, resnet
- from .._api import _get_enum_from_fn, WeightsEnum
- from .._utils import handle_legacy_interface, IntermediateLayerGetter
class BackboneWithFPN(nn.Module):
    """
    Wraps a backbone with a Feature Pyramid Network (FPN).

    The backbone is wrapped in torchvision.models._utils.IntermediateLayerGetter so
    that it yields the intermediate feature maps named in ``return_layers``; those
    maps are then passed through a FeaturePyramidNetwork.
    The same limitations of IntermediateLayerGetter apply here.

    Args:
        backbone (nn.Module): the model to extract feature maps from.
        return_layers (Dict[name, new_name]): a dict whose keys are the names of
            the modules whose activations will be returned, and whose values are
            the names given to the returned activations (user-specified).
        in_channels_list (List[int]): number of channels for each feature map
            that is returned, in the order they are present in the OrderedDict.
        out_channels (int): number of channels in the FPN.
        extra_blocks (ExtraFPNBlock, optional): forwarded to the FPN. When ``None``,
            a ``LastLevelMaxPool`` instance is created and used. Default: None
        norm_layer (callable, optional): Module specifying the normalization layer
            to use; forwarded to the FPN. Default: None

    Attributes:
        body: the IntermediateLayerGetter wrapping ``backbone``.
        fpn: the FeaturePyramidNetwork applied to the output of ``body``.
        out_channels (int): the number of channels in the FPN.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        # Fall back to a max-pool extra level when the caller supplies nothing.
        fpn_extra_blocks = LastLevelMaxPool() if extra_blocks is None else extra_blocks

        # NOTE: `body` is assigned before `fpn` on purpose; the assignment order is
        # kept identical to the original so submodule ordering does not change.
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=fpn_extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run ``x`` through the backbone body, then through the FPN."""
        features = self.body(x)
        pyramid = self.fpn(features)
        return pyramid
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """
    Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.

    The constructor named ``backbone_name`` is looked up in the ``resnet`` module and
    called with ``weights`` and ``norm_layer``; the resulting model is then handed to
    ``_resnet_fpn_extractor`` together with the remaining arguments.

    Examples::

        >>> import torch
        >>> from torchvision.models import ResNet50_Weights
        >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
        >>> backbone = resnet_fpn_backbone(
        >>>     backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3
        >>> )
        >>> # get some dummy image
        >>> x = torch.rand(1,3,64,64)
        >>> # compute the output
        >>> output = backbone(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('0', torch.Size([1, 256, 16, 16])),
        >>>    ('1', torch.Size([1, 256, 8, 8])),
        >>>    ('2', torch.Size([1, 256, 4, 4])),
        >>>    ('3', torch.Size([1, 256, 2, 2])),
        >>>    ('pool', torch.Size([1, 256, 1, 1]))]

    Args:
        backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
             'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
        weights (WeightsEnum, optional): The pretrained weights for the model
        norm_layer (callable): it is recommended to use the default value. For details visit:
            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
        trainable_layers (int): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
        returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
            By default, all layers are returned.
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names. By
            default, a ``LastLevelMaxPool`` is used.

    Returns:
        BackboneWithFPN: the ResNet wrapped with an FPN.
    """
    # Resolve the constructor by name first, then build the model from it.
    build_resnet = resnet.__dict__[backbone_name]
    resnet_model = build_resnet(weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(resnet_model, trainable_layers, returned_layers, extra_blocks)
def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
    """
    Freeze part of a ResNet in place and wrap it in a ``BackboneWithFPN``.

    Args:
        backbone (resnet.ResNet): the model to wrap. Parameters whose names do not
            start with one of the trainable layer names get ``requires_grad_(False)``.
        trainable_layers (int): number of trainable (not frozen) layers, counted from
            the final block in the order layer4, layer3, layer2, layer1, conv1.
            Must be in ``[0, 5]``; with 5, ``bn1`` is kept trainable as well.
        returned_layers (list of int, optional): which ``layer{k}`` outputs to feed to
            the FPN. Each entry must be in ``[1, 4]``. Defaults to ``[1, 2, 3, 4]``.
        extra_blocks (ExtraFPNBlock, optional): forwarded to ``BackboneWithFPN``.
            Defaults to a ``LastLevelMaxPool``.
        norm_layer (callable, optional): forwarded to ``BackboneWithFPN``.

    Returns:
        BackboneWithFPN: the wrapped backbone, built with 256 output channels.

    Raises:
        ValueError: if ``trainable_layers`` is outside ``[0, 5]``, or if
            ``returned_layers`` is empty or has an entry outside ``[1, 4]``.
    """
    # Validate every argument *before* touching the backbone, so that a rejected
    # call does not leave the caller's model with some parameters already frozen.
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    # The explicit emptiness check keeps min()/max() from failing with an
    # unrelated "arg is an empty sequence" message.
    if len(returned_layers) == 0 or min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")

    # select layers that won't be frozen
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    # str.startswith accepts a tuple of prefixes; an empty tuple never matches,
    # so trainable_layers == 0 freezes every parameter.
    trainable_prefixes = tuple(layers_to_train)
    for name, parameter in backbone.named_parameters():
        if not name.startswith(trainable_prefixes):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    # Maps e.g. "layer2" -> "1": the FPN sees features keyed by their position
    # in returned_layers.
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
    # NOTE(review): presumably backbone.inplanes is the channel width of layer4, so
    # dividing by 8 gives layer1's width and each later layer doubles it - confirm
    # against resnet.ResNet.
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )
def _validate_trainable_layers(
    is_trained: bool,
    trainable_backbone_layers: Optional[int],
    max_value: int,
    default_value: int,
) -> int:
    """
    Resolve and range-check the number of trainable backbone layers.

    Args:
        is_trained (bool): whether pretrained weights are in use. When ``False``,
            the result is always ``max_value`` (nothing is frozen).
        trainable_backbone_layers (int, optional): the requested value. ``None``
            means "use ``default_value``" when ``is_trained`` is ``True``.
        max_value (int): the largest accepted value.
        default_value (int): the value used when nothing was requested.

    Returns:
        int: the resolved number of trainable layers, within ``[0, max_value]``.

    Raises:
        ValueError: if the resolved value is outside ``[0, max_value]``.
    """
    if is_trained:
        # by default freeze first blocks
        resolved = default_value if trainable_backbone_layers is None else trainable_backbone_layers
    else:
        # don't freeze any layers if pretrained model or backbone is not used;
        # warn only when an explicit request is being overridden
        if trainable_backbone_layers is not None:
            warnings.warn(
                "Changing trainable_backbone_layers has no effect if "
                "neither pretrained nor pretrained_backbone have been set to True, "
                f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
            )
        resolved = max_value

    if resolved < 0 or resolved > max_value:
        raise ValueError(
            f"Trainable backbone layers should be in the range [0,{max_value}], got {resolved} "
        )
    return resolved
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def mobilenet_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    fpn: bool,
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 2,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
    """
    Constructs a specified MobileNet backbone, optionally with an FPN on top.

    The constructor named ``backbone_name`` is looked up in the ``mobilenet`` module
    and called with ``weights`` and ``norm_layer``; the resulting model is then handed
    to ``_mobilenet_extractor`` together with the remaining arguments.

    Args:
        backbone_name (string): name of the model constructor in the ``mobilenet`` module.
        weights (WeightsEnum, optional): The pretrained weights for the model
        fpn (bool): if ``True``, the features are wrapped in a ``BackboneWithFPN``;
            otherwise a 1x1 convolution producing 256 channels is appended to them.
        norm_layer (callable): normalization layer passed to the model constructor.
            Default: ``misc_nn_ops.FrozenBatchNorm2d``
        trainable_layers (int): number of trainable (not frozen) stages starting from
            the final one. Default: 2
        returned_layers (list of int, optional): indices of the stages whose outputs
            are returned. Only used when ``fpn`` is ``True``; by default the last two
            stages are returned.
        extra_blocks (ExtraFPNBlock or None): extra FPN block. Only used when ``fpn``
            is ``True``; by default a ``LastLevelMaxPool`` is used.

    Returns:
        nn.Module: the backbone produced by ``_mobilenet_extractor``.
    """
    # Resolve the constructor by name first, then build the model from it.
    build_mobilenet = mobilenet.__dict__[backbone_name]
    mobilenet_model = build_mobilenet(weights=weights, norm_layer=norm_layer)
    return _mobilenet_extractor(mobilenet_model, fpn, trainable_layers, returned_layers, extra_blocks)
def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    fpn: bool,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> nn.Module:
    """
    Freeze part of a MobileNet's ``features`` in place and turn it into a detection backbone.

    Args:
        backbone (MobileNetV2 or MobileNetV3): the model; only ``backbone.features`` is used.
        fpn (bool): if ``True``, return a ``BackboneWithFPN`` built on the selected
            stages; otherwise return the features followed by a 1x1 convolution.
        trainable_layers (int): number of trainable (not frozen) stages counted from
            the end. Must be in ``[0, num_stages]``; 0 freezes every block.
        returned_layers (list of int, optional): stage indices fed to the FPN, each in
            ``[0, num_stages - 1]``. Only used when ``fpn`` is ``True``; defaults to
            the last two stages.
        extra_blocks (ExtraFPNBlock, optional): forwarded to ``BackboneWithFPN``. Only
            used when ``fpn`` is ``True``; defaults to a ``LastLevelMaxPool``.
        norm_layer (callable, optional): forwarded to ``BackboneWithFPN``. Only used
            when ``fpn`` is ``True``.

    Returns:
        nn.Module: a module exposing ``out_channels`` (256).

    Raises:
        ValueError: if ``trainable_layers`` is out of range, or (with ``fpn=True``) if
            ``returned_layers`` is empty or has an entry out of range.
    """
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # Validate every argument *before* freezing anything, so that a rejected call
    # does not leave the caller's model with some parameters already frozen.
    if trainable_layers < 0 or trainable_layers > num_stages:
        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
    if fpn:
        if returned_layers is None:
            returned_layers = [num_stages - 2, num_stages - 1]
        # The explicit emptiness check keeps min()/max() from failing with an
        # unrelated "arg is an empty sequence" message.
        if len(returned_layers) == 0 or min(returned_layers) < 0 or max(returned_layers) >= num_stages:
            raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")

    # find the index of the layer from which we won't freeze
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    out_channels = 256
    if not fpn:
        m = nn.Sequential(
            backbone,
            # 1x1 (pointwise) convolution: a per-location linear combination of the
            # channels that projects them down to out_channels
            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
        )
        m.out_channels = out_channels  # type: ignore[assignment]
        return m

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    # Keys are the block indices inside `features`; values are the positions of
    # the stages in returned_layers, which become the FPN feature names.
    return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
    in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )
|