transform.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. import math
  2. from typing import Any, Dict, List, Optional, Tuple
  3. import torch
  4. import torchvision
  5. from torch import nn, Tensor
  6. from .image_list import ImageList
  7. from .roi_heads import paste_masks_in_image
  8. @torch.jit.unused
  9. def _get_shape_onnx(image: Tensor) -> Tensor:
  10. from torch.onnx import operators
  11. return operators.shape_as_tensor(image)[-2:]
  12. @torch.jit.unused
  13. def _fake_cast_onnx(v: Tensor) -> float:
  14. # ONNX requires a tensor but here we fake its type for JIT.
  15. return v
  16. def _resize_image_and_masks(
  17. image: Tensor,
  18. self_min_size: int,
  19. self_max_size: int,
  20. target: Optional[Dict[str, Tensor]] = None,
  21. fixed_size: Optional[Tuple[int, int]] = None,
  22. ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
  23. if torchvision._is_tracing():
  24. im_shape = _get_shape_onnx(image)
  25. else:
  26. im_shape = torch.tensor(image.shape[-2:])
  27. size: Optional[List[int]] = None
  28. scale_factor: Optional[float] = None
  29. recompute_scale_factor: Optional[bool] = None
  30. if fixed_size is not None:
  31. size = [fixed_size[1], fixed_size[0]]
  32. else:
  33. if torch.jit.is_scripting() or torchvision._is_tracing():
  34. min_size = torch.min(im_shape).to(dtype=torch.float32)
  35. max_size = torch.max(im_shape).to(dtype=torch.float32)
  36. self_min_size_f = float(self_min_size)
  37. self_max_size_f = float(self_max_size)
  38. scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
  39. if torchvision._is_tracing():
  40. scale_factor = _fake_cast_onnx(scale)
  41. else:
  42. scale_factor = scale.item()
  43. else:
  44. # Do it the normal way
  45. min_size = min(im_shape)
  46. max_size = max(im_shape)
  47. scale_factor = min(self_min_size / min_size, self_max_size / max_size)
  48. recompute_scale_factor = True
  49. image = torch.nn.functional.interpolate(
  50. image[None],
  51. size=size,
  52. scale_factor=scale_factor,
  53. mode="bilinear",
  54. recompute_scale_factor=recompute_scale_factor,
  55. align_corners=False,
  56. )[0]
  57. if target is None:
  58. return image, target
  59. if "masks" in target:
  60. mask = target["masks"]
  61. mask = torch.nn.functional.interpolate(
  62. mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
  63. )[:, 0].byte()
  64. target["masks"] = mask
  65. return image, target
  66. class GeneralizedRCNNTransform(nn.Module):
  67. """
  68. Performs input / target transformation before feeding the data to a GeneralizedRCNN
  69. model.
  70. The transformations it performs are:
  71. - input normalization (mean subtraction and std division)
  72. - input / target resizing to match min_size / max_size
  73. It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
  74. """
  75. def __init__(
  76. self,
  77. min_size: int,
  78. max_size: int,
  79. image_mean: List[float],
  80. image_std: List[float],
  81. size_divisible: int = 32,
  82. fixed_size: Optional[Tuple[int, int]] = None,
  83. **kwargs: Any,
  84. ):
  85. super().__init__()
  86. if not isinstance(min_size, (list, tuple)):
  87. min_size = (min_size,)
  88. self.min_size = min_size
  89. self.max_size = max_size
  90. self.image_mean = image_mean
  91. self.image_std = image_std
  92. self.size_divisible = size_divisible
  93. self.fixed_size = fixed_size
  94. self._skip_resize = kwargs.pop("_skip_resize", False)
  95. def forward(
  96. self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
  97. ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
  98. images = [img for img in images]
  99. if targets is not None:
  100. # make a copy of targets to avoid modifying it in-place
  101. # once torchscript supports dict comprehension
  102. # this can be simplified as follows
  103. # targets = [{k: v for k,v in t.items()} for t in targets]
  104. targets_copy: List[Dict[str, Tensor]] = []
  105. for t in targets:
  106. data: Dict[str, Tensor] = {}
  107. for k, v in t.items():
  108. data[k] = v
  109. targets_copy.append(data)
  110. targets = targets_copy
  111. for i in range(len(images)):
  112. image = images[i]
  113. target_index = targets[i] if targets is not None else None
  114. if image.dim() != 3:
  115. raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
  116. image = self.normalize(image)
  117. image, target_index = self.resize(image, target_index)
  118. images[i] = image
  119. if targets is not None and target_index is not None:
  120. targets[i] = target_index
  121. image_sizes = [img.shape[-2:] for img in images]
  122. images = self.batch_images(images, size_divisible=self.size_divisible)
  123. image_sizes_list: List[Tuple[int, int]] = []
  124. for image_size in image_sizes:
  125. torch._assert(
  126. len(image_size) == 2,
  127. f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
  128. )
  129. image_sizes_list.append((image_size[0], image_size[1]))
  130. image_list = ImageList(images, image_sizes_list)
  131. return image_list, targets
  132. def normalize(self, image: Tensor) -> Tensor:
  133. if not image.is_floating_point():
  134. raise TypeError(
  135. f"Expected input images to be of floating type (in range [0, 1]), "
  136. f"but found type {image.dtype} instead"
  137. )
  138. dtype, device = image.dtype, image.device
  139. mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
  140. std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
  141. return (image - mean[:, None, None]) / std[:, None, None]
  142. def torch_choice(self, k: List[int]) -> int:
  143. """
  144. Implements `random.choice` via torch ops, so it can be compiled with
  145. TorchScript and we use PyTorch's RNG (not native RNG)
  146. """
  147. index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
  148. return k[index]
  149. def resize(
  150. self,
  151. image: Tensor,
  152. target: Optional[Dict[str, Tensor]] = None,
  153. ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
  154. h, w = image.shape[-2:]
  155. if self.training:
  156. if self._skip_resize:
  157. return image, target
  158. size = self.torch_choice(self.min_size)
  159. else:
  160. size = self.min_size[-1]
  161. image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
  162. if target is None:
  163. return image, target
  164. bbox = target["boxes"]
  165. bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
  166. target["boxes"] = bbox
  167. if "keypoints" in target:
  168. keypoints = target["keypoints"]
  169. keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
  170. target["keypoints"] = keypoints
  171. return image, target
  172. # _onnx_batch_images() is an implementation of
  173. # batch_images() that is supported by ONNX tracing.
  174. @torch.jit.unused
  175. def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
  176. max_size = []
  177. for i in range(images[0].dim()):
  178. max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
  179. max_size.append(max_size_i)
  180. stride = size_divisible
  181. max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
  182. max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
  183. max_size = tuple(max_size)
  184. # work around for
  185. # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
  186. # which is not yet supported in onnx
  187. padded_imgs = []
  188. for img in images:
  189. padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
  190. padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
  191. padded_imgs.append(padded_img)
  192. return torch.stack(padded_imgs)
  193. def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
  194. maxes = the_list[0]
  195. for sublist in the_list[1:]:
  196. for index, item in enumerate(sublist):
  197. maxes[index] = max(maxes[index], item)
  198. return maxes
  199. def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
  200. if torchvision._is_tracing():
  201. # batch_images() does not export well to ONNX
  202. # call _onnx_batch_images() instead
  203. return self._onnx_batch_images(images, size_divisible)
  204. max_size = self.max_by_axis([list(img.shape) for img in images])
  205. stride = float(size_divisible)
  206. max_size = list(max_size)
  207. max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
  208. max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
  209. batch_shape = [len(images)] + max_size
  210. batched_imgs = images[0].new_full(batch_shape, 0)
  211. for i in range(batched_imgs.shape[0]):
  212. img = images[i]
  213. batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
  214. return batched_imgs
  215. def postprocess(
  216. self,
  217. result: List[Dict[str, Tensor]],
  218. image_shapes: List[Tuple[int, int]],
  219. original_image_sizes: List[Tuple[int, int]],
  220. ) -> List[Dict[str, Tensor]]:
  221. if self.training:
  222. return result
  223. for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
  224. boxes = pred["boxes"]
  225. boxes = resize_boxes(boxes, im_s, o_im_s)
  226. result[i]["boxes"] = boxes
  227. if "masks" in pred:
  228. masks = pred["masks"]
  229. masks = paste_masks_in_image(masks, boxes, o_im_s)
  230. result[i]["masks"] = masks
  231. if "keypoints" in pred:
  232. keypoints = pred["keypoints"]
  233. keypoints = resize_keypoints(keypoints, im_s, o_im_s)
  234. result[i]["keypoints"] = keypoints
  235. return result
  236. def __repr__(self) -> str:
  237. format_string = f"{self.__class__.__name__}("
  238. _indent = "\n "
  239. format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
  240. format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
  241. format_string += "\n)"
  242. return format_string
  243. def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
  244. ratios = [
  245. torch.tensor(s, dtype=torch.float32, device=keypoints.device)
  246. / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
  247. for s, s_orig in zip(new_size, original_size)
  248. ]
  249. ratio_h, ratio_w = ratios
  250. resized_data = keypoints.clone()
  251. if torch._C._get_tracing_state():
  252. resized_data_0 = resized_data[:, :, 0] * ratio_w
  253. resized_data_1 = resized_data[:, :, 1] * ratio_h
  254. resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
  255. else:
  256. resized_data[..., 0] *= ratio_w
  257. resized_data[..., 1] *= ratio_h
  258. return resized_data
  259. def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
  260. ratios = [
  261. torch.tensor(s, dtype=torch.float32, device=boxes.device)
  262. / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
  263. for s, s_orig in zip(new_size, original_size)
  264. ]
  265. ratio_height, ratio_width = ratios
  266. xmin, ymin, xmax, ymax = boxes.unbind(1)
  267. xmin = xmin * ratio_width
  268. xmax = xmax * ratio_width
  269. ymin = ymin * ratio_height
  270. ymax = ymax * ratio_height
  271. return torch.stack((xmin, ymin, xmax, ymax), dim=1)