feature_pyramid_network.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. from collections import OrderedDict
  2. from typing import Callable, Dict, List, Optional, Tuple
  3. import torch.nn.functional as F
  4. from torch import nn, Tensor
  5. from ..ops.misc import Conv2dNormActivation
  6. from ..utils import _log_api_usage_once
  7. class ExtraFPNBlock(nn.Module):
  8. """
  9. Base class for the extra block in the FPN.
  10. Args:
  11. results (List[Tensor]): the result of the FPN
  12. x (List[Tensor]): the original feature maps
  13. names (List[str]): the names for each one of the
  14. original feature maps
  15. Returns:
  16. results (List[Tensor]): the extended set of results
  17. of the FPN
  18. names (List[str]): the extended set of names for the results
  19. """
  20. def forward(
  21. self,
  22. results: List[Tensor],
  23. x: List[Tensor],
  24. names: List[str],
  25. ) -> Tuple[List[Tensor], List[str]]:
  26. pass
  27. class FeaturePyramidNetwork(nn.Module):
  28. """
  29. Module that adds a FPN from on top of a set of feature maps. This is based on
  30. `"Feature Pyramid Network for Object Detection" <https://arxiv.org/abs/1612.03144>`_.
  31. The feature maps are currently supposed to be in increasing depth
  32. order.
  33. The input to the model is expected to be an OrderedDict[Tensor], containing
  34. the feature maps on top of which the FPN will be added.
  35. Args:
  36. in_channels_list (list[int]): number of channels for each feature map that
  37. is passed to the module
  38. out_channels (int): number of channels of the FPN representation
  39. extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
  40. be performed. It is expected to take the fpn features, the original
  41. features and the names of the original features as input, and returns
  42. a new list of feature maps and their corresponding names
  43. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  44. Examples::
  45. >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
  46. >>> # get some dummy data
  47. >>> x = OrderedDict()
  48. >>> x['feat0'] = torch.rand(1, 10, 64, 64)
  49. >>> x['feat2'] = torch.rand(1, 20, 16, 16)
  50. >>> x['feat3'] = torch.rand(1, 30, 8, 8)
  51. >>> # compute the FPN on top of x
  52. >>> output = m(x)
  53. >>> print([(k, v.shape) for k, v in output.items()])
  54. >>> # returns
  55. >>> [('feat0', torch.Size([1, 5, 64, 64])),
  56. >>> ('feat2', torch.Size([1, 5, 16, 16])),
  57. >>> ('feat3', torch.Size([1, 5, 8, 8]))]
  58. """
  59. _version = 2
  60. def __init__(
  61. self,
  62. in_channels_list: List[int],
  63. out_channels: int,
  64. extra_blocks: Optional[ExtraFPNBlock] = None,
  65. norm_layer: Optional[Callable[..., nn.Module]] = None,
  66. ):
  67. super().__init__()
  68. _log_api_usage_once(self)
  69. self.inner_blocks = nn.ModuleList()
  70. self.layer_blocks = nn.ModuleList()
  71. for in_channels in in_channels_list:
  72. if in_channels == 0:
  73. raise ValueError("in_channels=0 is currently not supported")
  74. inner_block_module = Conv2dNormActivation(
  75. in_channels, out_channels, kernel_size=1, padding=0, norm_layer=norm_layer, activation_layer=None
  76. )
  77. layer_block_module = Conv2dNormActivation(
  78. out_channels, out_channels, kernel_size=3, norm_layer=norm_layer, activation_layer=None
  79. )
  80. self.inner_blocks.append(inner_block_module)
  81. self.layer_blocks.append(layer_block_module)
  82. # initialize parameters now to avoid modifying the initialization of top_blocks
  83. for m in self.modules():
  84. if isinstance(m, nn.Conv2d):
  85. nn.init.kaiming_uniform_(m.weight, a=1)
  86. if m.bias is not None:
  87. nn.init.constant_(m.bias, 0)
  88. if extra_blocks is not None:
  89. if not isinstance(extra_blocks, ExtraFPNBlock):
  90. raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
  91. self.extra_blocks = extra_blocks
  92. def _load_from_state_dict(
  93. self,
  94. state_dict,
  95. prefix,
  96. local_metadata,
  97. strict,
  98. missing_keys,
  99. unexpected_keys,
  100. error_msgs,
  101. ):
  102. version = local_metadata.get("version", None)
  103. if version is None or version < 2:
  104. num_blocks = len(self.inner_blocks)
  105. for block in ["inner_blocks", "layer_blocks"]:
  106. for i in range(num_blocks):
  107. for type in ["weight", "bias"]:
  108. old_key = f"{prefix}{block}.{i}.{type}"
  109. new_key = f"{prefix}{block}.{i}.0.{type}"
  110. if old_key in state_dict:
  111. state_dict[new_key] = state_dict.pop(old_key)
  112. super()._load_from_state_dict(
  113. state_dict,
  114. prefix,
  115. local_metadata,
  116. strict,
  117. missing_keys,
  118. unexpected_keys,
  119. error_msgs,
  120. )
  121. def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
  122. """
  123. This is equivalent to self.inner_blocks[idx](x),
  124. but torchscript doesn't support this yet
  125. """
  126. num_blocks = len(self.inner_blocks)
  127. if idx < 0:
  128. idx += num_blocks
  129. out = x
  130. for i, module in enumerate(self.inner_blocks):
  131. if i == idx:
  132. out = module(x)
  133. return out
  134. def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
  135. """
  136. This is equivalent to self.layer_blocks[idx](x),
  137. but torchscript doesn't support this yet
  138. """
  139. num_blocks = len(self.layer_blocks)
  140. if idx < 0:
  141. idx += num_blocks
  142. out = x
  143. for i, module in enumerate(self.layer_blocks):
  144. if i == idx:
  145. out = module(x)
  146. return out
  147. def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
  148. """
  149. Computes the FPN for a set of feature maps.
  150. Args:
  151. x (OrderedDict[Tensor]): feature maps for each feature level.
  152. Returns:
  153. results (OrderedDict[Tensor]): feature maps after FPN layers.
  154. They are ordered from the highest resolution first.
  155. """
  156. # unpack OrderedDict into two lists for easier handling
  157. names = list(x.keys())
  158. x = list(x.values())
  159. last_inner = self.get_result_from_inner_blocks(x[-1], -1)
  160. results = []
  161. results.append(self.get_result_from_layer_blocks(last_inner, -1))
  162. for idx in range(len(x) - 2, -1, -1):
  163. inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
  164. feat_shape = inner_lateral.shape[-2:]
  165. inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
  166. last_inner = inner_lateral + inner_top_down
  167. results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
  168. if self.extra_blocks is not None:
  169. results, names = self.extra_blocks(results, x, names)
  170. # make it back an OrderedDict
  171. out = OrderedDict([(k, v) for k, v in zip(names, results)])
  172. return out
  173. class LastLevelMaxPool(ExtraFPNBlock):
  174. """
  175. Applies a max_pool2d (not actual max_pool2d, we just subsample) on top of the last feature map
  176. """
  177. def forward(
  178. self,
  179. x: List[Tensor],
  180. y: List[Tensor],
  181. names: List[str],
  182. ) -> Tuple[List[Tensor], List[str]]:
  183. names.append("pool")
  184. # Use max pooling to simulate stride 2 subsampling
  185. x.append(F.max_pool2d(x[-1], kernel_size=1, stride=2, padding=0))
  186. return x, names
  187. class LastLevelP6P7(ExtraFPNBlock):
  188. """
  189. This module is used in RetinaNet to generate extra layers, P6 and P7.
  190. """
  191. def __init__(self, in_channels: int, out_channels: int):
  192. super().__init__()
  193. self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
  194. self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
  195. for module in [self.p6, self.p7]:
  196. nn.init.kaiming_uniform_(module.weight, a=1)
  197. nn.init.constant_(module.bias, 0)
  198. self.use_P5 = in_channels == out_channels
  199. def forward(
  200. self,
  201. p: List[Tensor],
  202. c: List[Tensor],
  203. names: List[str],
  204. ) -> Tuple[List[Tensor], List[str]]:
  205. p5, c5 = p[-1], c[-1]
  206. x = p5 if self.use_P5 else c5
  207. p6 = self.p6(x)
  208. p7 = self.p7(F.relu(p6))
  209. p.extend([p6, p7])
  210. names.extend(["p6", "p7"])
  211. return p, names