losses.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. from typing import List, Optional
  2. import torch
  3. from torch import nn, Tensor
  4. from torch.nn import functional as F
  5. from torchvision.prototype.models.depth.stereo.raft_stereo import grid_sample, make_coords_grid
  6. def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
  7. """Function to create a 2D Gaussian kernel."""
  8. x = torch.arange(kernel_size, dtype=torch.float32)
  9. y = torch.arange(kernel_size, dtype=torch.float32)
  10. x = x - (kernel_size - 1) / 2
  11. y = y - (kernel_size - 1) / 2
  12. x, y = torch.meshgrid(x, y)
  13. grid = (x**2 + y**2) / (2 * sigma**2)
  14. kernel = torch.exp(-grid)
  15. kernel = kernel / kernel.sum()
  16. return kernel
  17. def _sequence_loss_fn(
  18. flow_preds: List[Tensor],
  19. flow_gt: Tensor,
  20. valid_flow_mask: Optional[Tensor],
  21. gamma: Tensor,
  22. max_flow: int = 256,
  23. exclude_large: bool = False,
  24. weights: Optional[Tensor] = None,
  25. ):
  26. """Loss function defined over sequence of flow predictions"""
  27. torch._assert(
  28. gamma < 1,
  29. "sequence_loss: `gamma` must be lower than 1, but got {}".format(gamma),
  30. )
  31. if exclude_large:
  32. # exclude invalid pixels and extremely large diplacements
  33. flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
  34. if valid_flow_mask is not None:
  35. valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
  36. else:
  37. valid_flow_mask = flow_norm < max_flow
  38. if valid_flow_mask is not None:
  39. valid_flow_mask = valid_flow_mask.unsqueeze(1)
  40. flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
  41. abs_diff = (flow_preds - flow_gt).abs()
  42. if valid_flow_mask is not None:
  43. abs_diff = abs_diff * valid_flow_mask.unsqueeze(0)
  44. abs_diff = abs_diff.mean(axis=(1, 2, 3, 4))
  45. num_predictions = flow_preds.shape[0]
  46. # allocating on CPU and moving to device during run-time can force
  47. # an unwanted GPU synchronization that produces a large overhead
  48. if weights is None or len(weights) != num_predictions:
  49. weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)
  50. flow_loss = (abs_diff * weights).sum()
  51. return flow_loss, weights
  52. class SequenceLoss(nn.Module):
  53. def __init__(self, gamma: float = 0.8, max_flow: int = 256, exclude_large_flows: bool = False) -> None:
  54. """
  55. Args:
  56. gamma: value for the exponential weighting of the loss across frames
  57. max_flow: maximum flow value to exclude
  58. exclude_large_flows: whether to exclude large flows
  59. """
  60. super().__init__()
  61. self.max_flow = max_flow
  62. self.excluding_large = exclude_large_flows
  63. self.register_buffer("gamma", torch.tensor([gamma]))
  64. # cache the scale factor for the loss
  65. self._weights = None
  66. def forward(self, flow_preds: List[Tensor], flow_gt: Tensor, valid_flow_mask: Optional[Tensor]) -> Tensor:
  67. """
  68. Args:
  69. flow_preds: list of flow predictions of shape (batch_size, C, H, W)
  70. flow_gt: ground truth flow of shape (batch_size, C, H, W)
  71. valid_flow_mask: mask of valid flow pixels of shape (batch_size, H, W)
  72. """
  73. loss, weights = _sequence_loss_fn(
  74. flow_preds, flow_gt, valid_flow_mask, self.gamma, self.max_flow, self.excluding_large, self._weights
  75. )
  76. self._weights = weights
  77. return loss
  78. def set_gamma(self, gamma: float) -> None:
  79. self.gamma.fill_(gamma)
  80. # reset the cached scale factor
  81. self._weights = None
  82. def _ssim_loss_fn(
  83. source: Tensor,
  84. reference: Tensor,
  85. kernel: Tensor,
  86. eps: float = 1e-8,
  87. c1: float = 0.01**2,
  88. c2: float = 0.03**2,
  89. use_padding: bool = False,
  90. ) -> Tensor:
  91. # ref: Algorithm section: https://en.wikipedia.org/wiki/Structural_similarity
  92. # ref: Alternative implementation: https://kornia.readthedocs.io/en/latest/_modules/kornia/metrics/ssim.html#ssim
  93. torch._assert(
  94. source.ndim == reference.ndim == 4,
  95. "SSIM: `source` and `reference` must be 4-dimensional tensors",
  96. )
  97. torch._assert(
  98. source.shape == reference.shape,
  99. "SSIM: `source` and `reference` must have the same shape, but got {} and {}".format(
  100. source.shape, reference.shape
  101. ),
  102. )
  103. B, C, H, W = source.shape
  104. kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(C, 1, 1, 1)
  105. if use_padding:
  106. pad_size = kernel.shape[2] // 2
  107. source = F.pad(source, (pad_size, pad_size, pad_size, pad_size), "reflect")
  108. reference = F.pad(reference, (pad_size, pad_size, pad_size, pad_size), "reflect")
  109. mu1 = F.conv2d(source, kernel, groups=C)
  110. mu2 = F.conv2d(reference, kernel, groups=C)
  111. mu1_sq = mu1.pow(2)
  112. mu2_sq = mu2.pow(2)
  113. mu1_mu2 = mu1 * mu2
  114. mu_img1_sq = F.conv2d(source.pow(2), kernel, groups=C)
  115. mu_img2_sq = F.conv2d(reference.pow(2), kernel, groups=C)
  116. mu_img1_mu2 = F.conv2d(source * reference, kernel, groups=C)
  117. sigma1_sq = mu_img1_sq - mu1_sq
  118. sigma2_sq = mu_img2_sq - mu2_sq
  119. sigma12 = mu_img1_mu2 - mu1_mu2
  120. numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
  121. denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
  122. ssim = numerator / (denominator + eps)
  123. # doing 1 - ssim because we want to maximize the ssim
  124. return 1 - ssim.mean(dim=(1, 2, 3))
  125. class SSIM(nn.Module):
  126. def __init__(
  127. self,
  128. kernel_size: int = 11,
  129. max_val: float = 1.0,
  130. sigma: float = 1.5,
  131. eps: float = 1e-12,
  132. use_padding: bool = True,
  133. ) -> None:
  134. """SSIM loss function.
  135. Args:
  136. kernel_size: size of the Gaussian kernel
  137. max_val: constant scaling factor
  138. sigma: sigma of the Gaussian kernel
  139. eps: constant for division by zero
  140. use_padding: whether to pad the input tensor such that we have a score for each pixel
  141. """
  142. super().__init__()
  143. self.kernel_size = kernel_size
  144. self.max_val = max_val
  145. self.sigma = sigma
  146. gaussian_kernel = make_gaussian_kernel(kernel_size, sigma)
  147. self.register_buffer("gaussian_kernel", gaussian_kernel)
  148. self.c1 = (0.01 * self.max_val) ** 2
  149. self.c2 = (0.03 * self.max_val) ** 2
  150. self.use_padding = use_padding
  151. self.eps = eps
  152. def forward(self, source: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
  153. """
  154. Args:
  155. source: source image of shape (batch_size, C, H, W)
  156. reference: reference image of shape (batch_size, C, H, W)
  157. Returns:
  158. SSIM loss of shape (batch_size,)
  159. """
  160. return _ssim_loss_fn(
  161. source,
  162. reference,
  163. kernel=self.gaussian_kernel,
  164. c1=self.c1,
  165. c2=self.c2,
  166. use_padding=self.use_padding,
  167. eps=self.eps,
  168. )
  169. def _smoothness_loss_fn(img_gx: Tensor, img_gy: Tensor, val_gx: Tensor, val_gy: Tensor):
  170. # ref: https://github.com/nianticlabs/monodepth2/blob/b676244e5a1ca55564eb5d16ab521a48f823af31/layers.py#L202
  171. torch._assert(
  172. img_gx.ndim >= 3,
  173. "smoothness_loss: `img_gx` must be at least 3-dimensional tensor of shape (..., C, H, W)",
  174. )
  175. torch._assert(
  176. img_gx.ndim == val_gx.ndim,
  177. "smoothness_loss: `img_gx` and `depth_gx` must have the same dimensionality, but got {} and {}".format(
  178. img_gx.ndim, val_gx.ndim
  179. ),
  180. )
  181. for idx in range(img_gx.ndim):
  182. torch._assert(
  183. (img_gx.shape[idx] == val_gx.shape[idx] or (img_gx.shape[idx] == 1 or val_gx.shape[idx] == 1)),
  184. "smoothness_loss: `img_gx` and `depth_gx` must have either the same shape or broadcastable shape, but got {} and {}".format(
  185. img_gx.shape, val_gx.shape
  186. ),
  187. )
  188. # -3 is channel dimension
  189. weights_x = torch.exp(-torch.mean(torch.abs(val_gx), axis=-3, keepdim=True))
  190. weights_y = torch.exp(-torch.mean(torch.abs(val_gy), axis=-3, keepdim=True))
  191. smoothness_x = img_gx * weights_x
  192. smoothness_y = img_gy * weights_y
  193. smoothness = (torch.abs(smoothness_x) + torch.abs(smoothness_y)).mean(axis=(-3, -2, -1))
  194. return smoothness
  195. class SmoothnessLoss(nn.Module):
  196. def __init__(self) -> None:
  197. super().__init__()
  198. def _x_gradient(self, img: Tensor) -> Tensor:
  199. if img.ndim > 4:
  200. original_shape = img.shape
  201. is_reshaped = True
  202. img = img.reshape(-1, *original_shape[-3:])
  203. else:
  204. is_reshaped = False
  205. padded = F.pad(img, (0, 1, 0, 0), mode="replicate")
  206. grad = padded[..., :, :-1] - padded[..., :, 1:]
  207. if is_reshaped:
  208. grad = grad.reshape(original_shape)
  209. return grad
  210. def _y_gradient(self, x: torch.Tensor) -> torch.Tensor:
  211. if x.ndim > 4:
  212. original_shape = x.shape
  213. is_reshaped = True
  214. x = x.reshape(-1, *original_shape[-3:])
  215. else:
  216. is_reshaped = False
  217. padded = F.pad(x, (0, 0, 0, 1), mode="replicate")
  218. grad = padded[..., :-1, :] - padded[..., 1:, :]
  219. if is_reshaped:
  220. grad = grad.reshape(original_shape)
  221. return grad
  222. def forward(self, images: Tensor, vals: Tensor) -> Tensor:
  223. """
  224. Args:
  225. images: tensor of shape (D1, D2, ..., DN, C, H, W)
  226. vals: tensor of shape (D1, D2, ..., DN, 1, H, W)
  227. Returns:
  228. smoothness loss of shape (D1, D2, ..., DN)
  229. """
  230. img_gx = self._x_gradient(images)
  231. img_gy = self._y_gradient(images)
  232. val_gx = self._x_gradient(vals)
  233. val_gy = self._y_gradient(vals)
  234. return _smoothness_loss_fn(img_gx, img_gy, val_gx, val_gy)
  235. def _flow_sequence_consistency_loss_fn(
  236. flow_preds: List[Tensor],
  237. gamma: float = 0.8,
  238. resize_factor: float = 0.25,
  239. rescale_factor: float = 0.25,
  240. rescale_mode: str = "bilinear",
  241. weights: Optional[Tensor] = None,
  242. ):
  243. """Loss function defined over sequence of flow predictions"""
  244. # Simplified version of ref: https://arxiv.org/pdf/2006.11242.pdf
  245. # In the original paper, an additional refinement network is used to refine a flow prediction.
  246. # Each step performed by the recurrent module in Raft or CREStereo is a refinement step using a delta_flow update.
  247. # which should be consistent with the previous step. In this implementation, we simplify the overall loss
  248. # term and ignore left-right consistency loss or photometric loss which can be treated separately.
  249. torch._assert(
  250. rescale_factor <= 1.0,
  251. "sequence_consistency_loss: `rescale_factor` must be less than or equal to 1, but got {}".format(
  252. rescale_factor
  253. ),
  254. )
  255. flow_preds = torch.stack(flow_preds) # shape = (num_flow_updates, batch_size, 2, H, W)
  256. N, B, C, H, W = flow_preds.shape
  257. # rescale flow predictions to account for bilinear upsampling artifacts
  258. if rescale_factor:
  259. flow_preds = (
  260. F.interpolate(
  261. flow_preds.view(N * B, C, H, W), scale_factor=resize_factor, mode=rescale_mode, align_corners=True
  262. )
  263. ) * rescale_factor
  264. flow_preds = torch.stack(torch.chunk(flow_preds, N, dim=0), dim=0)
  265. # force the next prediction to be similar to the previous prediction
  266. abs_diff = (flow_preds[1:] - flow_preds[:-1]).square()
  267. abs_diff = abs_diff.mean(axis=(1, 2, 3, 4))
  268. num_predictions = flow_preds.shape[0] - 1 # because we are comparing differences
  269. if weights is None or len(weights) != num_predictions:
  270. weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)
  271. flow_loss = (abs_diff * weights).sum()
  272. return flow_loss, weights
  273. class FlowSequenceConsistencyLoss(nn.Module):
  274. def __init__(
  275. self,
  276. gamma: float = 0.8,
  277. resize_factor: float = 0.25,
  278. rescale_factor: float = 0.25,
  279. rescale_mode: str = "bilinear",
  280. ) -> None:
  281. super().__init__()
  282. self.gamma = gamma
  283. self.resize_factor = resize_factor
  284. self.rescale_factor = rescale_factor
  285. self.rescale_mode = rescale_mode
  286. self._weights = None
  287. def forward(self, flow_preds: List[Tensor]) -> Tensor:
  288. """
  289. Args:
  290. flow_preds: list of tensors of shape (batch_size, C, H, W)
  291. Returns:
  292. sequence consistency loss of shape (batch_size,)
  293. """
  294. loss, weights = _flow_sequence_consistency_loss_fn(
  295. flow_preds,
  296. gamma=self.gamma,
  297. resize_factor=self.resize_factor,
  298. rescale_factor=self.rescale_factor,
  299. rescale_mode=self.rescale_mode,
  300. weights=self._weights,
  301. )
  302. self._weights = weights
  303. return loss
  304. def set_gamma(self, gamma: float) -> None:
  305. self.gamma.fill_(gamma)
  306. # reset the cached scale factor
  307. self._weights = None
  308. def _psnr_loss_fn(source: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
  309. torch._assert(
  310. source.shape == target.shape,
  311. "psnr_loss: source and target must have the same shape, but got {} and {}".format(source.shape, target.shape),
  312. )
  313. # ref https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
  314. return 10 * torch.log10(max_val**2 / ((source - target).pow(2).mean(axis=(-3, -2, -1))))
  315. class PSNRLoss(nn.Module):
  316. def __init__(self, max_val: float = 256) -> None:
  317. """
  318. Args:
  319. max_val: maximum value of the input tensor. This refers to the maximum domain value of the input tensor.
  320. """
  321. super().__init__()
  322. self.max_val = max_val
  323. def forward(self, source: Tensor, target: Tensor) -> Tensor:
  324. """
  325. Args:
  326. source: tensor of shape (D1, D2, ..., DN, C, H, W)
  327. target: tensor of shape (D1, D2, ..., DN, C, H, W)
  328. Returns:
  329. psnr loss of shape (D1, D2, ..., DN)
  330. """
  331. # multiply by -1 as we want to maximize the psnr
  332. return -1 * _psnr_loss_fn(source, target, self.max_val)
  333. class FlowPhotoMetricLoss(nn.Module):
  334. def __init__(
  335. self,
  336. ssim_weight: float = 0.85,
  337. ssim_window_size: int = 11,
  338. ssim_max_val: float = 1.0,
  339. ssim_sigma: float = 1.5,
  340. ssim_eps: float = 1e-12,
  341. ssim_use_padding: bool = True,
  342. max_displacement_ratio: float = 0.15,
  343. ) -> None:
  344. super().__init__()
  345. self._ssim_loss = SSIM(
  346. kernel_size=ssim_window_size,
  347. max_val=ssim_max_val,
  348. sigma=ssim_sigma,
  349. eps=ssim_eps,
  350. use_padding=ssim_use_padding,
  351. )
  352. self._L1_weight = 1 - ssim_weight
  353. self._SSIM_weight = ssim_weight
  354. self._max_displacement_ratio = max_displacement_ratio
  355. def forward(
  356. self,
  357. source: Tensor,
  358. reference: Tensor,
  359. flow_pred: Tensor,
  360. valid_mask: Optional[Tensor] = None,
  361. ):
  362. """
  363. Args:
  364. source: tensor of shape (B, C, H, W)
  365. reference: tensor of shape (B, C, H, W)
  366. flow_pred: tensor of shape (B, 2, H, W)
  367. valid_mask: tensor of shape (B, H, W) or None
  368. Returns:
  369. photometric loss of shape
  370. """
  371. torch._assert(
  372. source.ndim == 4,
  373. "FlowPhotoMetricLoss: source must have 4 dimensions, but got {}".format(source.ndim),
  374. )
  375. torch._assert(
  376. reference.ndim == source.ndim,
  377. "FlowPhotoMetricLoss: source and other must have the same number of dimensions, but got {} and {}".format(
  378. source.ndim, reference.ndim
  379. ),
  380. )
  381. torch._assert(
  382. flow_pred.shape[1] == 2,
  383. "FlowPhotoMetricLoss: flow_pred must have 2 channels, but got {}".format(flow_pred.shape[1]),
  384. )
  385. torch._assert(
  386. flow_pred.ndim == 4,
  387. "FlowPhotoMetricLoss: flow_pred must have 4 dimensions, but got {}".format(flow_pred.ndim),
  388. )
  389. B, C, H, W = source.shape
  390. flow_channels = flow_pred.shape[1]
  391. max_displacements = []
  392. for dim in range(flow_channels):
  393. shape_index = -1 - dim
  394. max_displacements.append(int(self._max_displacement_ratio * source.shape[shape_index]))
  395. # mask out all pixels that have larger flow than the max flow allowed
  396. max_flow_mask = torch.logical_and(
  397. *[flow_pred[:, dim, :, :] < max_displacements[dim] for dim in range(flow_channels)]
  398. )
  399. if valid_mask is not None:
  400. valid_mask = torch.logical_and(valid_mask, max_flow_mask).unsqueeze(1)
  401. else:
  402. valid_mask = max_flow_mask.unsqueeze(1)
  403. grid = make_coords_grid(B, H, W, device=str(source.device))
  404. resampled_grids = grid - flow_pred
  405. resampled_grids = resampled_grids.permute(0, 2, 3, 1)
  406. resampled_source = grid_sample(reference, resampled_grids, mode="bilinear")
  407. # compute SSIM loss
  408. ssim_loss = self._ssim_loss(resampled_source * valid_mask, source * valid_mask)
  409. l1_loss = (resampled_source * valid_mask - source * valid_mask).abs().mean(axis=(-3, -2, -1))
  410. loss = self._L1_weight * l1_loss + self._SSIM_weight * ssim_loss
  411. return loss.mean()