poolers.py

from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.fx
import torchvision
from torch import nn, Tensor
from torchvision.ops.boxes import box_area

from ..utils import _log_api_usage_once
from .roi_align import roi_align


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros(
        (levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device
    )
    for level in range(len(unmerged_results)):
        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        index = index.expand(
            index.size(0),
            unmerged_results[level].size(1),
            unmerged_results[level].size(2),
            unmerged_results[level].size(3),
        )
        res = res.scatter(0, index, unmerged_results[level])
    return res

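# A minimal sketch (values assumed for illustration, not part of the library) of what
# the scatter-based merge above does: row i of the merged result is taken from the
# per-level output that RoI i was assigned to.
#
#   levels = torch.tensor([0, 1, 0])
#   per_level = [torch.ones(2, 1, 1, 1), 2 * torch.ones(1, 1, 1, 1)]
#   merged = _onnx_merge_levels(levels, per_level)
#   # merged[0] and merged[2] come from level 0, merged[1] from level 1
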
# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


class LevelMapper:
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Args:
        k_min (int)
        k_max (int)
        canonical_scale (int)
        canonical_level (int)
        eps (float)
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """
        Args:
            boxlists (List[Tensor[N, 4]]): per-image boxes in (x1, y1, x2, y2) format
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)

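# A rough worked example (assumed values) of the Eqn.(1) heuristic implemented above:
# with canonical_scale=224 and canonical_level=4, an RoI of roughly 112x112 pixels maps
# to level 3 and a 448x448 RoI to level 5, before clamping to [k_min, k_max] and
# shifting by k_min.
#
#   mapper = initLevelMapper(k_min=2, k_max=5)
#   boxes = [torch.tensor([[0.0, 0.0, 112.0, 112.0], [0.0, 0.0, 448.0, 448.0]])]
#   mapper(boxes)  # tensor([1, 3]), i.e. pyramid levels 3 and 5 offset by k_min=2
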
def _convert_to_roi_format(boxes: List[Tensor]) -> Tensor:
    concat_boxes = torch.cat(boxes, dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    ids = torch.cat(
        [torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device) for i, b in enumerate(boxes)],
        dim=0,
    )
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois

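# Sketch of the RoI format produced above (shapes assumed for illustration): each
# per-image box tensor of shape [N, 4] is concatenated and prefixed with its image
# index, giving one [K, 5] tensor in (batch_idx, x1, y1, x2, y2) form, which is the
# layout roi_align expects.
#
#   boxes = [torch.rand(2, 4), torch.rand(3, 4)]
#   rois = _convert_to_roi_format(boxes)  # shape [5, 5]; rois[:, 0] is 0, 0, 1, 1, 1 (as floats)
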
def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
    # assumption: the scale is of the form 2 ** (-k), with k integer
    size = feature.shape[-2:]
    possible_scales: List[float] = []
    for s1, s2 in zip(size, original_size):
        approx_scale = float(s1) / float(s2)
        scale = 2 ** float(torch.tensor(approx_scale).log2().round())
        possible_scales.append(scale)
    return possible_scales[0]

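# Example (assumed sizes) of the power-of-two rounding above: a 25x25 feature map
# computed from a 199x199 image gives approx_scale of about 0.1256, whose log2 (about
# -2.99) rounds to -3, so the inferred spatial scale is 2 ** -3 = 0.125.
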
@torch.fx.wrap
def _setup_scales(
    features: List[Tensor], image_shapes: List[Tuple[int, int]], canonical_scale: int, canonical_level: int
) -> Tuple[List[float], LevelMapper]:
    if not image_shapes:
        raise ValueError("images list should not be empty")
    max_x = 0
    max_y = 0
    for shape in image_shapes:
        max_x = max(shape[0], max_x)
        max_y = max(shape[1], max_y)
    original_input_shape = (max_x, max_y)

    scales = [_infer_scale(feat, original_input_shape) for feat in features]
    # get the levels in the feature map by leveraging the fact that the network always
    # downsamples by a factor of 2 at each level.
    lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
    lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()

    map_levels = initLevelMapper(
        int(lvl_min),
        int(lvl_max),
        canonical_scale=canonical_scale,
        canonical_level=canonical_level,
    )
    return scales, map_levels

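# Illustration (assumed values) of the level range derived above: for FPN feature maps
# with scales [1/4, 1/8, 1/16, 1/32], lvl_min = -log2(1/4) = 2 and lvl_max = -log2(1/32) = 5,
# so the resulting LevelMapper assigns RoIs to pyramid levels 2 through 5.
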
@torch.fx.wrap
def _filter_input(x: Dict[str, Tensor], featmap_names: List[str]) -> List[Tensor]:
    x_filtered = []
    for k, v in x.items():
        if k in featmap_names:
            x_filtered.append(v)
    return x_filtered


@torch.fx.wrap
def _multiscale_roi_align(
    x_filtered: List[Tensor],
    boxes: List[Tensor],
    output_size: List[int],
    sampling_ratio: int,
    scales: Optional[List[float]],
    mapper: Optional[LevelMapper],
) -> Tensor:
    """
    Args:
        x_filtered (List[Tensor]): List of input tensors.
        boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
            (x1, y1, x2, y2) format and in the image reference size, not the feature map
            reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        output_size (Union[List[Tuple[int, int]], List[int]]): size of the output
        sampling_ratio (int): sampling ratio for ROIAlign
        scales (Optional[List[float]]): If None, scales will be automatically inferred. Default value is None.
        mapper (Optional[LevelMapper]): If None, mapper will be automatically inferred. Default value is None.

    Returns:
        result (Tensor)
    """
    if scales is None or mapper is None:
        raise ValueError("scales and mapper should not be None")

    num_levels = len(x_filtered)
    rois = _convert_to_roi_format(boxes)

    if num_levels == 1:
        return roi_align(
            x_filtered[0],
            rois,
            output_size=output_size,
            spatial_scale=scales[0],
            sampling_ratio=sampling_ratio,
        )

    levels = mapper(boxes)

    num_rois = len(rois)
    num_channels = x_filtered[0].shape[1]

    dtype, device = x_filtered[0].dtype, x_filtered[0].device
    result = torch.zeros(
        (
            num_rois,
            num_channels,
        )
        + output_size,
        dtype=dtype,
        device=device,
    )

    tracing_results = []
    for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
        idx_in_level = torch.where(levels == level)[0]
        rois_per_level = rois[idx_in_level]

        result_idx_in_level = roi_align(
            per_level_feature,
            rois_per_level,
            output_size=output_size,
            spatial_scale=scale,
            sampling_ratio=sampling_ratio,
        )

        if torchvision._is_tracing():
            tracing_results.append(result_idx_in_level.to(dtype))
        else:
            # result and result_idx_in_level's dtypes are based on dtypes of different
            # elements in x_filtered. x_filtered contains tensors output by different
            # layers. When autocast is active, it may choose different dtypes for
            # different layers' outputs. Therefore, we defensively match result's dtype
            # before copying elements from result_idx_in_level in the following op.
            # We need to cast manually (can't rely on autocast to cast for us) because
            # the op acts on result in-place, and autocast only affects out-of-place ops.
            result[idx_in_level] = result_idx_in_level.to(result.dtype)

    if torchvision._is_tracing():
        result = _onnx_merge_levels(levels, tracing_results)

    return result

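# Minimal usage sketch (assumed tensors and sizes, mirroring MultiScaleRoIAlign.forward
# below) of how the helpers above compose:
#
#   feats = [torch.rand(1, 5, 64, 64), torch.rand(1, 5, 32, 32)]
#   boxes = [torch.tensor([[10.0, 10.0, 50.0, 50.0]])]
#   scales, mapper = _setup_scales(feats, [(256, 256)], canonical_scale=224, canonical_level=4)
#   pooled = _multiscale_roi_align(feats, boxes, (7, 7), 2, scales, mapper)
#   # pooled has shape [1, 5, 7, 7]; the single RoI is pooled from the 64x64 (scale 1/4) map
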
class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics specified in eq. 1
    of the `Feature Pyramid Network paper <https://arxiv.org/abs/1612.03144>`_.
    The keyword-only parameters ``canonical_scale`` and ``canonical_level``
    correspond respectively to ``224`` and ``k0=4`` in eq. 1, and
    have the following meaning: ``canonical_level`` is the target level of the pyramid from
    which to pool a region of interest with ``w x h = canonical_scale x canonical_scale``.

    Args:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign
        canonical_scale (int, optional): canonical_scale for LevelMapper
        canonical_level (int, optional): canonical_level for LevelMapper

    Examples::

        >>> from collections import OrderedDict
        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])
    """

    __annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]}

    def __init__(
        self,
        featmap_names: List[str],
        output_size: Union[int, Tuple[int], List[int]],
        sampling_ratio: int,
        *,
        canonical_scale: int = 224,
        canonical_level: int = 4,
    ):
        super().__init__()
        _log_api_usage_once(self)
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None
        self.canonical_scale = canonical_scale
        self.canonical_level = canonical_level

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Args:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.

        Returns:
            result (Tensor)
        """
        x_filtered = _filter_input(x, self.featmap_names)
        if self.scales is None or self.map_levels is None:
            self.scales, self.map_levels = _setup_scales(
                x_filtered, image_shapes, self.canonical_scale, self.canonical_level
            )

        return _multiscale_roi_align(
            x_filtered,
            boxes,
            self.output_size,
            self.sampling_ratio,
            self.scales,
            self.map_levels,
        )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
            f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})"
        )