roi_align.py
from typing import List, Union

import torch
import torch._dynamo
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


# NB: all inputs are tensors
def _bilinear_interpolate(
    input,  # [N, C, H, W]
    roi_batch_ind,  # [K]
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
    ymask,  # [K, IY]
    xmask,  # [K, IX]
):
    _, channels, height, width = input.size()

    # deal with inverse element out of feature map boundary
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

    ly = y - y_low
    lx = x - x_low
    hy = 1.0 - ly
    hx = 1.0 - lx

    # do bilinear interpolation, but respect the masking!
    # TODO: It's possible the masking here is unnecessary if y and
    # x were clamped appropriately; hard to tell
    def masked_index(
        y,  # [K, PH, IY]
        x,  # [K, PW, IX]
    ):
        if ymask is not None:
            assert xmask is not None
            y = torch.where(ymask[:, None, :], y, 0)
            x = torch.where(xmask[:, None, :], x, 0)
        return input[
            roi_batch_ind[:, None, None, None, None, None],
            torch.arange(channels, device=input.device)[None, :, None, None, None, None],
            y[:, None, :, None, :, None],  # prev [K, PH, IY]
            x[:, None, None, :, None, :],  # prev [K, PW, IX]
        ]  # [K, C, PH, PW, IY, IX]

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)

    # all ws preemptively [K, C, PH, PW, IY, IX]
    def outer_prod(y, x):
        return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

    w1 = outer_prod(hy, hx)
    w2 = outer_prod(hy, lx)
    w3 = outer_prod(ly, hx)
    w4 = outer_prod(ly, lx)

    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val
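
# Illustrative shape walk-through for _bilinear_interpolate (an added sketch,
# not from the original source). With K=2 ROIs, C=3 channels, PH=PW=2 pooled
# bins, and IY=IX=2 sampling points per bin:
#   y, ly, hy are [2, 2, 2] ([K, PH, IY]); x, lx, hx are [2, 2, 2] ([K, PW, IX]).
#   masked_index(...) gathers to [2, 3, 2, 2, 2, 2] ([K, C, PH, PW, IY, IX]).
#   outer_prod(hy, hx) broadcasts [K, 1, PH, 1, IY, 1] * [K, 1, 1, PW, 1, IX]
#   -> [K, 1, PH, PW, IY, IX], which then broadcasts against the gathered
#   corner values over the channel dimension in w1 * v1 + ... + w4 * v4.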


# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
        return tensor.float()
    else:
        return tensor
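
# For example (an added note): under torch.autocast("cuda"), a float16 CUDA
# tensor would be upcast to float32 here so the pure-Python kernel below
# accumulates in full precision, while float64 inputs pass through unchanged.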


# This is a slow but pure Python and differentiable implementation of
# roi_align. It potentially is a good basis for Inductor compilation
# (but I have not benchmarked it) but today it is solely used for the
# fact that its backwards can be implemented deterministically,
# which is needed for the PT2 benchmark suite.
#
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    orig_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    _, _, height, width = input.size()

    ph = torch.arange(pooled_height, device=input.device)  # [PH]
    pw = torch.arange(pooled_width, device=input.device)  # [PW]

    # input: [N, C, H, W]
    # rois: [K, 5]
    roi_batch_ind = rois[:, 0].int()  # [K]
    offset = 0.5 if aligned else 0.0
    roi_start_w = rois[:, 1] * spatial_scale - offset  # [K]
    roi_start_h = rois[:, 2] * spatial_scale - offset  # [K]
    roi_end_w = rois[:, 3] * spatial_scale - offset  # [K]
    roi_end_h = rois[:, 4] * spatial_scale - offset  # [K]

    roi_width = roi_end_w - roi_start_w  # [K]
    roi_height = roi_end_h - roi_start_h  # [K]
    if not aligned:
        roi_width = torch.clamp(roi_width, min=1.0)  # [K]
        roi_height = torch.clamp(roi_height, min=1.0)  # [K]

    bin_size_h = roi_height / pooled_height  # [K]
    bin_size_w = roi_width / pooled_width  # [K]

    exact_sampling = sampling_ratio > 0

    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)  # scalar or [K]
    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)  # scalar or [K]

    """
    iy, ix = dims(2)
    """

    if exact_sampling:
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
        iy = torch.arange(roi_bin_grid_h, device=input.device)  # [IY]
        ix = torch.arange(roi_bin_grid_w, device=input.device)  # [IX]
        ymask = None
        xmask = None
    else:
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
        # When doing adaptive sampling, the number of samples we need to do
        # is data-dependent based on how big the ROIs are. This is a bit
        # awkward because first-class dims can't actually handle this.
        # So instead, we inefficiently suppose that we needed to sample ALL
        # the points and mask out things that turned out to be unnecessary
        iy = torch.arange(height, device=input.device)  # [IY]
        ix = torch.arange(width, device=input.device)  # [IX]
        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]
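        # For instance (added illustration): with height = 4 and
        # roi_bin_grid_h = tensor([2., 3.]), ymask is
        #   [[True, True, False, False],
        #    [True, True, True,  False]]
        # i.e. ROI k keeps only its first roi_bin_grid_h[k] vertical samples.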

    def from_K(t):
        return t[:, None, None]

    y = (
        from_K(roi_start_h)
        + ph[None, :, None] * from_K(bin_size_h)
        + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
    )  # [K, PH, IY]
    x = (
        from_K(roi_start_w)
        + pw[None, :, None] * from_K(bin_size_w)
        + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
    )  # [K, PW, IX]
    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]
    if isinstance(count, torch.Tensor):
        output /= count[:, None, None, None]
    else:
        output /= count

    output = output.to(orig_dtype)

    return output
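
# A minimal sketch of exercising the pure-Python fallback directly (added
# illustration; the public roi_align() below normally decides between this
# and the compiled torchvision op):
#
#     feat = torch.rand(1, 3, 16, 16)
#     rois = torch.tensor([[0.0, 0.0, 0.0, 8.0, 8.0]])  # [batch_idx, x1, y1, x2, y2]
#     out = _roi_align(feat, rois, spatial_scale=1.0, pooled_height=2,
#                      pooled_width=2, sampling_ratio=-1, aligned=True)
#     assert out.shape == torch.Size([1, 3, 2, 2])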


@torch.fx.wrap
def roi_align(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
) -> Tensor:
    """
    Performs Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.

    Args:
        input (Tensor[N, C, H, W]): The input tensor, i.e. a batch with ``N`` elements. Each element
            contains ``C`` feature maps of dimensions ``H x W``.
            If the tensor is quantized, we expect a batch size of ``N == 1``.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
            format where the regions will be taken from.
            The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            If a single Tensor is passed, then the first column should
            contain the index of the corresponding element in the batch, i.e. a number in ``[0, N - 1]``.
            If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i
            in the batch.
        output_size (int or Tuple[int, int]): the size of the output (in bins or pixels) after the pooling
            is performed, as (height, width).
        spatial_scale (float): a scaling factor that maps the box coordinates to
            the input coordinates. For example, if your boxes are defined on the scale
            of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of
            the original image), you'll want to set this to 0.5. Default: 1.0
        sampling_ratio (int): number of sampling points in the interpolation grid
            used to compute the output value of each pooled output bin. If > 0,
            then exactly ``sampling_ratio x sampling_ratio`` sampling points per bin are used. If
            <= 0, then an adaptive number of grid points is used (computed as
            ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
        aligned (bool): If False, use the legacy implementation.
            If True, pixel shift the box coordinates by -0.5 for a better alignment with the two
            neighboring pixel indices. This version is used in Detectron2.

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_align)
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    if not torch.jit.is_scripting():
        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
    )
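
# Example usage (added sketch; shapes follow the docstring above):
#
#     feat = torch.rand(2, 256, 50, 50)
#     boxes = torch.tensor([[0.0, 10.0, 10.0, 30.0, 30.0],
#                           [1.0, 5.0, 5.0, 20.0, 20.0]])  # [batch_idx, x1, y1, x2, y2]
#     pooled = roi_align(feat, boxes, output_size=(7, 7), spatial_scale=1.0,
#                        sampling_ratio=2, aligned=True)
#     assert pooled.shape == torch.Size([2, 256, 7, 7])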


class RoIAlign(nn.Module):
    """
    See :func:`roi_align`.
    """

    def __init__(
        self,
        output_size: BroadcastingList2[int],
        spatial_scale: float,
        sampling_ratio: int,
        aligned: bool = False,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"output_size={self.output_size}"
            f", spatial_scale={self.spatial_scale}"
            f", sampling_ratio={self.sampling_ratio}"
            f", aligned={self.aligned}"
            f")"
        )
        return s
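
# Example usage of the module wrapper (added sketch, equivalent to the
# functional call above):
#
#     layer = RoIAlign(output_size=(7, 7), spatial_scale=0.5, sampling_ratio=2, aligned=True)
#     pooled = layer(feat, boxes)
#     print(layer)  # RoIAlign(output_size=(7, 7), spatial_scale=0.5, sampling_ratio=2, aligned=True)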