backbone_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import warnings
  2. from typing import Callable, Dict, List, Optional, Union
  3. from torch import nn, Tensor
  4. from torchvision.ops import misc as misc_nn_ops
  5. from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
  6. from .. import mobilenet, resnet
  7. from .._api import _get_enum_from_fn, WeightsEnum
  8. from .._utils import handle_legacy_interface, IntermediateLayerGetter
  9. class BackboneWithFPN(nn.Module):
  10. """
  11. Adds a FPN on top of a model.
  12. Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
  13. extract a submodel that returns the feature maps specified in return_layers.
  14. The same limitations of IntermediateLayerGetter apply here.
  15. Args:
  16. backbone (nn.Module)
  17. return_layers (Dict[name, new_name]): a dict containing the names
  18. of the modules for which the activations will be returned as
  19. the key of the dict, and the value of the dict is the name
  20. of the returned activation (which the user can specify).
  21. in_channels_list (List[int]): number of channels for each feature map
  22. that is returned, in the order they are present in the OrderedDict
  23. out_channels (int): number of channels in the FPN.
  24. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  25. Attributes:
  26. out_channels (int): the number of channels in the FPN
  27. """
  28. def __init__(
  29. self,
  30. backbone: nn.Module,
  31. return_layers: Dict[str, str],
  32. in_channels_list: List[int],
  33. out_channels: int,
  34. extra_blocks: Optional[ExtraFPNBlock] = None,
  35. norm_layer: Optional[Callable[..., nn.Module]] = None,
  36. ) -> None:
  37. super().__init__()
  38. if extra_blocks is None:
  39. extra_blocks = LastLevelMaxPool()
  40. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  41. self.fpn = FeaturePyramidNetwork(
  42. in_channels_list=in_channels_list,
  43. out_channels=out_channels,
  44. extra_blocks=extra_blocks,
  45. norm_layer=norm_layer,
  46. )
  47. self.out_channels = out_channels
  48. def forward(self, x: Tensor) -> Dict[str, Tensor]:
  49. x = self.body(x)
  50. x = self.fpn(x)
  51. return x
  52. @handle_legacy_interface(
  53. weights=(
  54. "pretrained",
  55. lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
  56. ),
  57. )
  58. def resnet_fpn_backbone(
  59. *,
  60. backbone_name: str,
  61. weights: Optional[WeightsEnum],
  62. norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
  63. trainable_layers: int = 3,
  64. returned_layers: Optional[List[int]] = None,
  65. extra_blocks: Optional[ExtraFPNBlock] = None,
  66. ) -> BackboneWithFPN:
  67. """
  68. Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
  69. Examples::
  70. >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
  71. >>> backbone = resnet_fpn_backbone('resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
  72. >>> # get some dummy image
  73. >>> x = torch.rand(1,3,64,64)
  74. >>> # compute the output
  75. >>> output = backbone(x)
  76. >>> print([(k, v.shape) for k, v in output.items()])
  77. >>> # returns
  78. >>> [('0', torch.Size([1, 256, 16, 16])),
  79. >>> ('1', torch.Size([1, 256, 8, 8])),
  80. >>> ('2', torch.Size([1, 256, 4, 4])),
  81. >>> ('3', torch.Size([1, 256, 2, 2])),
  82. >>> ('pool', torch.Size([1, 256, 1, 1]))]
  83. Args:
  84. backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
  85. 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
  86. weights (WeightsEnum, optional): The pretrained weights for the model
  87. norm_layer (callable): it is recommended to use the default value. For details visit:
  88. (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
  89. trainable_layers (int): number of trainable (not frozen) layers starting from final block.
  90. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
  91. returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
  92. By default, all layers are returned.
  93. extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
  94. be performed. It is expected to take the fpn features, the original
  95. features and the names of the original features as input, and returns
  96. a new list of feature maps and their corresponding names. By
  97. default, a ``LastLevelMaxPool`` is used.
  98. """
  99. backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
  100. return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
  101. def _resnet_fpn_extractor(
  102. backbone: resnet.ResNet,
  103. trainable_layers: int,
  104. returned_layers: Optional[List[int]] = None,
  105. extra_blocks: Optional[ExtraFPNBlock] = None,
  106. norm_layer: Optional[Callable[..., nn.Module]] = None,
  107. ) -> BackboneWithFPN:
  108. # select layers that won't be frozen
  109. if trainable_layers < 0 or trainable_layers > 5:
  110. raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
  111. layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
  112. if trainable_layers == 5:
  113. layers_to_train.append("bn1")
  114. for name, parameter in backbone.named_parameters():
  115. if all([not name.startswith(layer) for layer in layers_to_train]):
  116. parameter.requires_grad_(False)
  117. if extra_blocks is None:
  118. extra_blocks = LastLevelMaxPool()
  119. if returned_layers is None:
  120. returned_layers = [1, 2, 3, 4]
  121. if min(returned_layers) <= 0 or max(returned_layers) >= 5:
  122. raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
  123. return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
  124. in_channels_stage2 = backbone.inplanes // 8
  125. in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
  126. out_channels = 256
  127. return BackboneWithFPN(
  128. backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
  129. )
  130. def _validate_trainable_layers(
  131. is_trained: bool,
  132. trainable_backbone_layers: Optional[int],
  133. max_value: int,
  134. default_value: int,
  135. ) -> int:
  136. # don't freeze any layers if pretrained model or backbone is not used
  137. if not is_trained:
  138. if trainable_backbone_layers is not None:
  139. warnings.warn(
  140. "Changing trainable_backbone_layers has no effect if "
  141. "neither pretrained nor pretrained_backbone have been set to True, "
  142. f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
  143. )
  144. trainable_backbone_layers = max_value
  145. # by default freeze first blocks
  146. if trainable_backbone_layers is None:
  147. trainable_backbone_layers = default_value
  148. if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
  149. raise ValueError(
  150. f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
  151. )
  152. return trainable_backbone_layers
  153. @handle_legacy_interface(
  154. weights=(
  155. "pretrained",
  156. lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
  157. ),
  158. )
  159. def mobilenet_backbone(
  160. *,
  161. backbone_name: str,
  162. weights: Optional[WeightsEnum],
  163. fpn: bool,
  164. norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
  165. trainable_layers: int = 2,
  166. returned_layers: Optional[List[int]] = None,
  167. extra_blocks: Optional[ExtraFPNBlock] = None,
  168. ) -> nn.Module:
  169. backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
  170. return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
  171. def _mobilenet_extractor(
  172. backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
  173. fpn: bool,
  174. trainable_layers: int,
  175. returned_layers: Optional[List[int]] = None,
  176. extra_blocks: Optional[ExtraFPNBlock] = None,
  177. norm_layer: Optional[Callable[..., nn.Module]] = None,
  178. ) -> nn.Module:
  179. backbone = backbone.features
  180. # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
  181. # The first and last blocks are always included because they are the C0 (conv1) and Cn.
  182. stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
  183. num_stages = len(stage_indices)
  184. # find the index of the layer from which we won't freeze
  185. if trainable_layers < 0 or trainable_layers > num_stages:
  186. raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
  187. freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
  188. for b in backbone[:freeze_before]:
  189. for parameter in b.parameters():
  190. parameter.requires_grad_(False)
  191. out_channels = 256
  192. if fpn:
  193. if extra_blocks is None:
  194. extra_blocks = LastLevelMaxPool()
  195. if returned_layers is None:
  196. returned_layers = [num_stages - 2, num_stages - 1]
  197. if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
  198. raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
  199. return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
  200. in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
  201. return BackboneWithFPN(
  202. backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
  203. )
  204. else:
  205. m = nn.Sequential(
  206. backbone,
  207. # depthwise linear combination of channels to reduce their size
  208. nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
  209. )
  210. m.out_channels = out_channels # type: ignore[assignment]
  211. return m