transforms.py

import random
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch import Tensor

T_FLOW = Union[Tensor, np.ndarray, None]
T_MASK = Union[Tensor, np.ndarray, None]
T_STEREO_TENSOR = Tuple[Tensor, Tensor]
T_COLOR_AUG_PARAM = Union[float, Tuple[float, float]]


def rand_float_range(size: Sequence[int], low: float, high: float) -> Tensor:
    return (low - high) * torch.rand(size) + high
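
# Note (illustrative, not part of the original module): torch.rand samples from [0, 1), so the
# expression above maps a draw of 0 to `high` and a draw of 1 to `low`, i.e. the result is
# uniform over (low, high]. For example, rand_float_range((1,), low=-2.0, high=2.0) draws a
# single value in (-2.0, 2.0].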


class InterpolationStrategy:
    _valid_modes: List[str] = ["mixed", "bicubic", "bilinear"]

    def __init__(self, mode: str = "mixed") -> None:
        if mode not in self._valid_modes:
            raise ValueError(f"Invalid interpolation mode: {mode}. Valid modes are: {self._valid_modes}")

        if mode == "mixed":
            self.strategies = [F.InterpolationMode.BILINEAR, F.InterpolationMode.BICUBIC]
        elif mode == "bicubic":
            self.strategies = [F.InterpolationMode.BICUBIC]
        elif mode == "bilinear":
            self.strategies = [F.InterpolationMode.BILINEAR]

    def __call__(self) -> F.InterpolationMode:
        return random.choice(self.strategies)

    @classmethod
    def is_valid(cls, mode: str) -> bool:
        return mode in cls._valid_modes

    @property
    def valid_modes(self) -> List[str]:
        return self._valid_modes
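
# Illustrative usage (a sketch, not part of the original module): the strategy is re-sampled
# on every call, so "mixed" alternates between the two modes across samples.
#
#   interp = InterpolationStrategy("mixed")
#   mode = interp()  # randomly InterpolationMode.BILINEAR or InterpolationMode.BICUBIC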


class ValidateModelInput(torch.nn.Module):
    # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
    def forward(self, images: T_STEREO_TENSOR, disparities: T_FLOW, masks: T_MASK):
        if images[0].shape != images[1].shape:
            raise ValueError("img1 and img2 should have the same shape.")

        h, w = images[0].shape[-2:]
        if disparities[0] is not None and disparities[0].shape != (1, h, w):
            raise ValueError(f"disparities[0].shape should be (1, {h}, {w}) instead of {disparities[0].shape}")
        if masks[0] is not None:
            if masks[0].shape != (h, w):
                raise ValueError(f"masks[0].shape should be ({h}, {w}) instead of {masks[0].shape}")
            if masks[0].dtype != torch.bool:
                raise TypeError(f"masks[0] should be of dtype torch.bool instead of {masks[0].dtype}")

        return images, disparities, masks


class ConvertToGrayscale(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        images: Tuple[PIL.Image.Image, PIL.Image.Image],
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left = F.rgb_to_grayscale(images[0], num_output_channels=3)
        img_right = F.rgb_to_grayscale(images[1], num_output_channels=3)
        return (img_left, img_right), disparities, masks


class MakeValidDisparityMask(torch.nn.Module):
    def __init__(self, max_disparity: Optional[int] = 256) -> None:
        super().__init__()
        self.max_disparity = max_disparity

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        valid_masks = tuple(
            torch.ones(images[idx].shape[-2:], dtype=torch.bool, device=images[idx].device) if mask is None else mask
            for idx, mask in enumerate(masks)
        )

        valid_masks = tuple(
            torch.logical_and(mask, disparity > 0).squeeze(0) if disparity is not None else mask
            for mask, disparity in zip(valid_masks, disparities)
        )

        if self.max_disparity is not None:
            valid_masks = tuple(
                torch.logical_and(mask, disparity < self.max_disparity).squeeze(0) if disparity is not None else mask
                for mask, disparity in zip(valid_masks, disparities)
            )

        return images, disparities, valid_masks


class ToGPU(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        dev_images = tuple(image.cuda() for image in images)
        dev_disparities = tuple(map(lambda x: x.cuda() if x is not None else None, disparities))
        dev_masks = tuple(map(lambda x: x.cuda() if x is not None else None, masks))
        return dev_images, dev_disparities, dev_masks


class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype: torch.dtype):
        super().__init__()
        self.dtype = dtype

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left = F.convert_image_dtype(images[0], dtype=self.dtype)
        img_right = F.convert_image_dtype(images[1], dtype=self.dtype)

        img_left = img_left.contiguous()
        img_right = img_right.contiguous()

        return (img_left, img_right), disparities, masks


class Normalize(torch.nn.Module):
    def __init__(self, mean: List[float], std: List[float]) -> None:
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left = F.normalize(images[0], mean=self.mean, std=self.std)
        img_right = F.normalize(images[1], mean=self.mean, std=self.std)

        img_left = img_left.contiguous()
        img_right = img_right.contiguous()

        return (img_left, img_right), disparities, masks


class ToTensor(torch.nn.Module):
    def forward(
        self,
        images: Tuple[PIL.Image.Image, PIL.Image.Image],
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if images[0] is None:
            raise ValueError("img_left is None")
        if images[1] is None:
            raise ValueError("img_right is None")

        img_left = F.pil_to_tensor(images[0])
        img_right = F.pil_to_tensor(images[1])
        disparity_tensors = ()
        mask_tensors = ()

        for idx in range(2):
            disparity_tensors += (torch.from_numpy(disparities[idx]),) if disparities[idx] is not None else (None,)
            mask_tensors += (torch.from_numpy(masks[idx]),) if masks[idx] is not None else (None,)

        return (img_left, img_right), disparity_tensors, mask_tensors


class AsymmetricColorJitter(T.ColorJitter):
    # p determines the probability of doing asymmetric vs symmetric color jittering
    def __init__(
        self,
        brightness: T_COLOR_AUG_PARAM = 0,
        contrast: T_COLOR_AUG_PARAM = 0,
        saturation: T_COLOR_AUG_PARAM = 0,
        hue: T_COLOR_AUG_PARAM = 0,
        p: float = 0.2,
    ):
        super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.p = p

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if torch.rand(1) < self.p:
            # asymmetric: different transform for img1 and img2
            img_left = super().forward(images[0])
            img_right = super().forward(images[1])
        else:
            # symmetric: same transform for img1 and img2
            batch = torch.stack(images)
            batch = super().forward(batch)
            img_left, img_right = batch[0], batch[1]

        return (img_left, img_right), disparities, masks


class AsymetricGammaAdjust(torch.nn.Module):
    def __init__(self, p: float, gamma_range: Tuple[float, float], gain: float = 1) -> None:
        super().__init__()
        self.gamma_range = gamma_range
        self.gain = gain
        self.p = p

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        gamma = rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()

        if torch.rand(1) < self.p:
            # asymmetric: different transform for img1 and img2
            img_left = F.adjust_gamma(images[0], gamma, gain=self.gain)
            img_right = F.adjust_gamma(images[1], gamma, gain=self.gain)
        else:
            # symmetric: same transform for img1 and img2
            batch = torch.stack(images)
            batch = F.adjust_gamma(batch, gamma, gain=self.gain)
            img_left, img_right = batch[0], batch[1]

        return (img_left, img_right), disparities, masks


class RandomErase(torch.nn.Module):
    # Produces multiple symmetric random erasures;
    # these can be viewed as occlusions present in both camera views.
    # Similarly to Optical Flow occlusion prediction tasks, we mask these pixels in the disparity map
    def __init__(
        self,
        p: float = 0.5,
        erase_px_range: Tuple[int, int] = (50, 100),
        value: Union[Tensor, float] = 0,
        inplace: bool = False,
        max_erase: int = 2,
    ):
        super().__init__()
        self.min_px_erase = erase_px_range[0]
        self.max_px_erase = erase_px_range[1]
        if self.max_px_erase < 0:
            raise ValueError("erase_px_range[1] should be greater than or equal to 0")
        if self.min_px_erase < 0:
            raise ValueError("erase_px_range[0] should be greater than or equal to 0")
        if self.min_px_erase > self.max_px_erase:
            raise ValueError("erase_px_range[0] should be lower than or equal to erase_px_range[1]")

        self.p = p
        self.value = value
        self.inplace = inplace
        self.max_erase = max_erase

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: T_STEREO_TENSOR,
        masks: T_STEREO_TENSOR,
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if torch.rand(1) < self.p:
            return images, disparities, masks

        image_left, image_right = images
        mask_left, mask_right = masks
        for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
            y, x, h, w, v = self._get_params(image_left)
            image_right = F.erase(image_right, y, x, h, w, v, self.inplace)
            image_left = F.erase(image_left, y, x, h, w, v, self.inplace)
            # similarly to optical flow occlusion prediction, we consider
            # any erased pixels that are in both images to be occluded,
            # therefore we mark them as invalid
            if mask_left is not None:
                mask_left = F.erase(mask_left, y, x, h, w, False, self.inplace)
            if mask_right is not None:
                mask_right = F.erase(mask_right, y, x, h, w, False, self.inplace)

        return (image_left, image_right), disparities, (mask_left, mask_right)

    def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
        img_h, img_w = img.shape[-2:]
        crop_h, crop_w = (
            random.randint(self.min_px_erase, self.max_px_erase),
            random.randint(self.min_px_erase, self.max_px_erase),
        )
        crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))

        return crop_y, crop_x, crop_h, crop_w, self.value


class RandomOcclusion(torch.nn.Module):
    # This adds an occlusion to the right image;
    # the occluded patch works as a patch erase where the erase value is the mean
    # of the pixels from the selected zone
    def __init__(self, p: float = 0.5, occlusion_px_range: Tuple[int, int] = (50, 100), inplace: bool = False):
        super().__init__()
        self.min_px_occlusion = occlusion_px_range[0]
        self.max_px_occlusion = occlusion_px_range[1]

        if self.max_px_occlusion < 0:
            raise ValueError("occlusion_px_range[1] should be greater than or equal to 0")
        if self.min_px_occlusion < 0:
            raise ValueError("occlusion_px_range[0] should be greater than or equal to 0")
        if self.min_px_occlusion > self.max_px_occlusion:
            raise ValueError("occlusion_px_range[0] should be lower than occlusion_px_range[1]")

        self.p = p
        self.inplace = inplace

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: T_STEREO_TENSOR,
        masks: T_STEREO_TENSOR,
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        left_image, right_image = images

        if torch.rand(1) < self.p:
            return images, disparities, masks

        y, x, h, w, v = self._get_params(right_image)
        right_image = F.erase(right_image, y, x, h, w, v, self.inplace)

        return ((left_image, right_image), disparities, masks)

    def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, float]:
        img_h, img_w = img.shape[-2:]
        crop_h, crop_w = (
            random.randint(self.min_px_occlusion, self.max_px_occlusion),
            random.randint(self.min_px_occlusion, self.max_px_occlusion),
        )
        crop_x, crop_y = (random.randint(0, img_w - crop_w), random.randint(0, img_h - crop_h))
        occlusion_value = img[..., crop_y : crop_y + crop_h, crop_x : crop_x + crop_w].mean(dim=(-2, -1), keepdim=True)

        return (crop_y, crop_x, crop_h, crop_w, occlusion_value)


class RandomSpatialShift(torch.nn.Module):
    # This transform applies a vertical shift and a slight angle rotation at the same time
    def __init__(
        self, p: float = 0.5, max_angle: float = 0.1, max_px_shift: int = 2, interpolation_type: str = "bilinear"
    ) -> None:
        super().__init__()
        self.p = p
        self.max_angle = max_angle
        self.max_px_shift = max_px_shift
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: T_STEREO_TENSOR,
        masks: T_STEREO_TENSOR,
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        # the transform is applied only on the right image
        # in order to mimic slight calibration issues
        img_left, img_right = images
        INTERP_MODE = self._interpolation_mode_strategy()

        if torch.rand(1) < self.p:
            # [0, 1] -> [-a, a]
            shift = rand_float_range((1,), low=-self.max_px_shift, high=self.max_px_shift).item()
            angle = rand_float_range((1,), low=-self.max_angle, high=self.max_angle).item()
            # sample center point for the rotation matrix
            y = torch.randint(size=(1,), low=0, high=img_right.shape[-2]).item()
            x = torch.randint(size=(1,), low=0, high=img_right.shape[-1]).item()
            # apply affine transformations
            img_right = F.affine(
                img_right,
                angle=angle,
                translate=[0, shift],  # translation only on the y-axis
                center=[x, y],
                scale=1.0,
                shear=0.0,
                interpolation=INTERP_MODE,
            )

        return ((img_left, img_right), disparities, masks)


class RandomHorizontalFlip(torch.nn.Module):
    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left, img_right = images
        dsp_left, dsp_right = disparities
        mask_left, mask_right = masks

        if dsp_right is not None and torch.rand(1) < self.p:
            img_left, img_right = F.hflip(img_left), F.hflip(img_right)
            dsp_left, dsp_right = F.hflip(dsp_left), F.hflip(dsp_right)
            if mask_left is not None and mask_right is not None:
                mask_left, mask_right = F.hflip(mask_left), F.hflip(mask_right)
            # after a horizontal flip the two views exchange roles, so the left/right
            # entries are swapped to keep the disparity sign convention valid
            return ((img_right, img_left), (dsp_right, dsp_left), (mask_right, mask_left))

        return images, disparities, masks


class Resize(torch.nn.Module):
    def __init__(self, resize_size: Tuple[int, ...], interpolation_type: str = "bilinear") -> None:
        super().__init__()
        self.resize_size = list(resize_size)  # doing this to keep mypy happy
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        resized_images = ()
        resized_disparities = ()
        resized_masks = ()

        INTERP_MODE = self._interpolation_mode_strategy()

        for img in images:
            # We hard-code antialias=False to preserve results after we changed
            # its default from None to True (see
            # https://github.com/pytorch/vision/pull/7160)
            # TODO: we could re-train the stereo models with antialias=True?
            resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE, antialias=False),)

        for dsp in disparities:
            if dsp is not None:
                # rescale disparity to match the new image size: disparities are horizontal
                # pixel offsets, so they scale with the width ratio of the resize
                scale_x = self.resize_size[1] / dsp.shape[-1]
                resized_disparities += (F.resize(dsp, self.resize_size, interpolation=INTERP_MODE) * scale_x,)
            else:
                resized_disparities += (None,)

        for mask in masks:
            if mask is not None:
                resized_masks += (
                    # we squeeze and unsqueeze because the API requires > 3D tensors
                    F.resize(
                        mask.unsqueeze(0),
                        self.resize_size,
                        interpolation=F.InterpolationMode.NEAREST,
                    ).squeeze(0),
                )
            else:
                resized_masks += (None,)

        return resized_images, resized_disparities, resized_masks


class RandomRescaleAndCrop(torch.nn.Module):
    # This transform will resize the input with a given probability, and then crop it.
    # These are the reversed operations of the built-in RandomResizedCrop,
    # although the order of the operations doesn't matter too much: resizing a
    # crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
    #
    # The reason we don't rely on RandomResizedCrop is because of a significant
    # difference in the parametrization of both transforms, in particular,
    # because of the way the random parameters are sampled in both transforms,
    # which leads to fairly different results (and different epe). For more details see
    # https://github.com/pytorch/vision/pull/5026/files#r762932579
    def __init__(
        self,
        crop_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (-0.2, 0.5),
        rescale_prob: float = 0.8,
        scaling_type: str = "exponential",
        interpolation_type: str = "bilinear",
    ) -> None:
        super().__init__()
        self.crop_size = crop_size
        self.min_scale = scale_range[0]
        self.max_scale = scale_range[1]
        self.rescale_prob = rescale_prob
        self.scaling_type = scaling_type
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

        if self.scaling_type == "linear" and self.min_scale < 0:
            raise ValueError("min_scale must be >= 0 for linear scaling")

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left, img_right = images
        dsp_left, dsp_right = disparities
        mask_left, mask_right = masks
        INTERP_MODE = self._interpolation_mode_strategy()

        # randomly sample scale
        h, w = img_left.shape[-2:]
        # Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
        # It shouldn't matter much
        min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)

        # exponential scaling will draw a random scale in (min_scale, max_scale) and then raise
        # 2 to the power of that random value. This final scale distribution will have a different
        # mean and variance than a uniform distribution. Note that a scale of 1 will result in
        # a rescaling of 2X the original size, whereas a scale of -1 will result in a rescaling
        # of 0.5X the original size.
        if self.scaling_type == "exponential":
            scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
        # linear scaling will draw a random scale in (min_scale, max_scale)
        elif self.scaling_type == "linear":
            scale = torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()

        scale = max(scale, min_scale)

        new_h, new_w = round(h * scale), round(w * scale)

        if torch.rand(1).item() < self.rescale_prob:
            # rescale the images
            img_left = F.resize(img_left, size=(new_h, new_w), interpolation=INTERP_MODE)
            img_right = F.resize(img_right, size=(new_h, new_w), interpolation=INTERP_MODE)

            resized_masks, resized_disparities = (), ()

            for disparity, mask in zip(disparities, masks):
                if disparity is not None:
                    if mask is None:
                        resized_disparity = F.resize(disparity, size=(new_h, new_w), interpolation=INTERP_MODE)
                        # rescale the disparity
                        resized_disparity = (
                            resized_disparity * torch.tensor([scale], device=resized_disparity.device)[:, None, None]
                        )
                        resized_mask = None
                    else:
                        resized_disparity, resized_mask = _resize_sparse_flow(
                            disparity, mask, scale_x=scale, scale_y=scale
                        )
                    resized_masks += (resized_mask,)
                    resized_disparities += (resized_disparity,)
        else:
            resized_disparities = disparities
            resized_masks = masks

        disparities = resized_disparities
        masks = resized_masks

        # Note: For sparse datasets (Kitti), the original code uses a "margin"
        # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
        # We don't, not sure if it matters much
        y0 = torch.randint(0, img_left.shape[1] - self.crop_size[0], size=(1,)).item()
        x0 = torch.randint(0, img_right.shape[2] - self.crop_size[1], size=(1,)).item()

        img_left = F.crop(img_left, y0, x0, self.crop_size[0], self.crop_size[1])
        img_right = F.crop(img_right, y0, x0, self.crop_size[0], self.crop_size[1])
        if dsp_left is not None:
            dsp_left = F.crop(disparities[0], y0, x0, self.crop_size[0], self.crop_size[1])
        if dsp_right is not None:
            dsp_right = F.crop(disparities[1], y0, x0, self.crop_size[0], self.crop_size[1])

        cropped_masks = ()
        for mask in masks:
            if mask is not None:
                mask = F.crop(mask, y0, x0, self.crop_size[0], self.crop_size[1])
            cropped_masks += (mask,)

        return ((img_left, img_right), (dsp_left, dsp_right), cropped_masks)


def _resize_sparse_flow(
    flow: Tensor, valid_flow_mask: Tensor, scale_x: float = 1.0, scale_y: float = 1.0
) -> Tuple[Tensor, Tensor]:
    # This resizes both the flow and the valid_flow_mask mask (which is assumed to be reasonably sparse)
    # There are as many non-zero values in the original flow as in the resized flow (up to OOB)
    # So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4
    h, w = flow.shape[-2:]

    h_new = int(round(h * scale_y))
    w_new = int(round(w * scale_x))
    flow_new = torch.zeros(size=[1, h_new, w_new], dtype=flow.dtype)
    valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)

    jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")

    ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]

    ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
    jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)

    within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)

    ii_valid = ii_valid[within_bounds_mask]
    jj_valid = jj_valid[within_bounds_mask]
    ii_valid_new = ii_valid_new[within_bounds_mask]
    jj_valid_new = jj_valid_new[within_bounds_mask]

    valid_flow_new = flow[:, ii_valid, jj_valid]
    valid_flow_new *= scale_x

    flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
    valid_new[ii_valid_new, jj_valid_new] = valid_flow_mask[ii_valid, jj_valid]

    return flow_new, valid_new.bool()
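
# Worked example (illustrative, not part of the original module): with scale_x = scale_y = 2.0,
# a valid disparity of 3.0 px at position (i=10, j=20) is moved to (i=20, j=40) in the resized
# map and its value becomes 6.0 px, while pixels without ground truth stay zero and invalid.
# The number of valid pixels is preserved (up to out-of-bounds rounding), so the same amount of
# valid data now covers an area four times larger.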


class Compose(torch.nn.Module):
    def __init__(self, transforms: List[Callable]):
        super().__init__()
        self.transforms = transforms

    @torch.inference_mode()
    def forward(self, images, disparities, masks):
        for t in self.transforms:
            images, disparities, masks = t(images, disparities, masks)
        return images, disparities, masks
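

# Illustrative training preset (hypothetical values, not part of the original module): the
# transforms above are meant to be chained with Compose, roughly as follows.
#
#   preset = Compose(
#       [
#           ToTensor(),
#           AsymmetricColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2, p=0.2),
#           RandomRescaleAndCrop(crop_size=(384, 512), scale_range=(-0.2, 0.5)),
#           RandomHorizontalFlip(p=0.5),
#           ConvertImageDtype(torch.float32),
#           MakeValidDisparityMask(max_disparity=256),
#           ValidateModelInput(),
#       ]
#   )
#   images, disparities, masks = preset(
#       (img_left, img_right), (dsp_left, dsp_right), (mask_left, mask_right)
#   )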