123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503 |
- from typing import List, Optional
- import torch
- from torch import nn, Tensor
- from torch.nn import functional as F
- from torchvision.prototype.models.depth.stereo.raft_stereo import grid_sample, make_coords_grid
def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
    """Create a normalized 2D Gaussian kernel.

    Args:
        kernel_size: side length of the square kernel
        sigma: standard deviation of the Gaussian

    Returns:
        float32 tensor of shape (kernel_size, kernel_size) whose entries sum to 1
    """
    # Offsets of every tap from the kernel center; the kernel is square, so the
    # same 1D offsets describe both axes.
    offsets = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    # Broadcasting a column against a row builds the 2D squared-distance grid
    # without `torch.meshgrid`, which warns when called without `indexing`.
    squared_distance = offsets[:, None] ** 2 + offsets[None, :] ** 2
    kernel = torch.exp(-squared_distance / (2 * sigma**2))
    return kernel / kernel.sum()
def _sequence_loss_fn(
    flow_preds: List[Tensor],
    flow_gt: Tensor,
    valid_flow_mask: Optional[Tensor],
    gamma: Tensor,
    max_flow: int = 256,
    exclude_large: bool = False,
    weights: Optional[Tensor] = None,
):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    Args:
        flow_preds: list of predictions, each broadcastable against ``flow_gt``
        flow_gt: ground truth flow
        valid_flow_mask: optional boolean mask; masked-out entries contribute zero
            to the loss but are still counted in the mean
        gamma: decay base, must be lower than 1; the last prediction gets weight 1
        max_flow: ground truth norms at or above this value are masked out when
            ``exclude_large`` is set
        exclude_large: whether to mask out large ground truth displacements
        weights: previously computed per-prediction weights; reused only when
            their length matches the number of predictions

    Returns:
        tuple of (scalar loss, per-prediction weights that were used)
    """
    torch._assert(
        gamma < 1,
        "sequence_loss: `gamma` must be lower than 1, but got {}".format(gamma),
    )

    mask = valid_flow_mask
    if exclude_large:
        # drop pixels whose ground truth displacement norm is too large
        small_enough = flow_gt.pow(2).sum(dim=1).sqrt() < max_flow
        mask = small_enough if mask is None else mask & small_enough

    stacked = torch.stack(flow_preds)  # shape = (num_flow_updates, batch_size, 2, H, W)
    l1_error = (stacked - flow_gt).abs()
    if mask is not None:
        # insert singleton axes for the prediction index and the channel
        l1_error = l1_error * mask[None, :, None]
    per_prediction_error = l1_error.mean(dim=(1, 2, 3, 4))

    num_steps = stacked.shape[0]
    if weights is None or len(weights) != num_steps:
        # built directly on the target device: allocating on CPU and moving the
        # tensor afterwards can force an unwanted GPU synchronization
        exponents = torch.arange(num_steps - 1, -1, -1, device=stacked.device, dtype=stacked.dtype)
        weights = gamma**exponents

    return (per_prediction_error * weights).sum(), weights
class SequenceLoss(nn.Module):
    """Module wrapper around ``_sequence_loss_fn`` that caches the per-prediction weights."""

    def __init__(self, gamma: float = 0.8, max_flow: int = 256, exclude_large_flows: bool = False) -> None:
        """
        Args:
            gamma: value for the exponential weighting of the loss across frames
            max_flow: maximum flow value to exclude
            exclude_large_flows: whether to exclude large flows
        """
        super().__init__()
        self.max_flow = max_flow
        self.excluding_large = exclude_large_flows
        # kept as a buffer so it follows the module across devices
        self.register_buffer("gamma", torch.tensor([gamma]))
        # weights returned by the last forward pass, reused on the next one
        self._weights = None

    def forward(self, flow_preds: List[Tensor], flow_gt: Tensor, valid_flow_mask: Optional[Tensor]) -> Tensor:
        """
        Args:
            flow_preds: list of flow predictions of shape (batch_size, C, H, W)
            flow_gt: ground truth flow of shape (batch_size, C, H, W)
            valid_flow_mask: mask of valid flow pixels of shape (batch_size, H, W)

        Returns:
            the weighted sequence loss
        """
        loss, used_weights = _sequence_loss_fn(
            flow_preds=flow_preds,
            flow_gt=flow_gt,
            valid_flow_mask=valid_flow_mask,
            gamma=self.gamma,
            max_flow=self.max_flow,
            exclude_large=self.excluding_large,
            weights=self._weights,
        )
        self._weights = used_weights
        return loss

    def set_gamma(self, gamma: float) -> None:
        """Overwrite the decay base in place and drop the cached weights."""
        self.gamma.fill_(gamma)
        # the cached weights were derived from the previous gamma
        self._weights = None
def _ssim_loss_fn(
    source: Tensor,
    reference: Tensor,
    kernel: Tensor,
    eps: float = 1e-8,
    c1: float = 0.01**2,
    c2: float = 0.03**2,
    use_padding: bool = False,
) -> Tensor:
    """Compute ``1 - SSIM`` between two image batches, one value per batch element.

    Args:
        source: tensor of shape (B, C, H, W)
        reference: tensor with the same shape as ``source``
        kernel: 2D window used for the local statistics, applied to each channel
        eps: constant added to the denominator
        c1: stabilizing constant of the mean term
        c2: stabilizing constant of the variance term
        use_padding: reflect-pad the inputs by half the window size before filtering

    Returns:
        tensor of shape (B,)
    """
    # ref: Algorithm section: https://en.wikipedia.org/wiki/Structural_similarity
    # ref: Alternative implementation: https://kornia.readthedocs.io/en/latest/_modules/kornia/metrics/ssim.html#ssim
    torch._assert(
        source.ndim == reference.ndim == 4,
        "SSIM: `source` and `reference` must be 4-dimensional tensors",
    )
    torch._assert(
        source.shape == reference.shape,
        "SSIM: `source` and `reference` must have the same shape, but got {} and {}".format(
            source.shape, reference.shape
        ),
    )

    channels = source.shape[1]
    # one copy of the window per channel, for the grouped convolution below
    window = kernel[None, None].repeat(channels, 1, 1, 1)

    if use_padding:
        border = window.shape[2] // 2
        padding = (border, border, border, border)
        source = F.pad(source, padding, "reflect")
        reference = F.pad(reference, padding, "reflect")

    def windowed(tensor: Tensor) -> Tensor:
        # per-channel filtering with the shared window
        return F.conv2d(tensor, window, groups=channels)

    mean_src = windowed(source)
    mean_ref = windowed(reference)
    mean_src_sq = mean_src**2
    mean_ref_sq = mean_ref**2
    mean_product = mean_src * mean_ref

    # second moments minus squared/crossed first moments
    var_src = windowed(source**2) - mean_src_sq
    var_ref = windowed(reference**2) - mean_ref_sq
    covariance = windowed(source * reference) - mean_product

    top = (2 * mean_product + c1) * (2 * covariance + c2)
    bottom = (mean_src_sq + mean_ref_sq + c1) * (var_src + var_ref + c2)
    ssim_map = top / (bottom + eps)

    # doing 1 - ssim because we want to maximize the ssim
    return 1 - ssim_map.mean(dim=(1, 2, 3))
class SSIM(nn.Module):
    """Module wrapper around ``_ssim_loss_fn`` with a Gaussian window stored as a buffer."""

    def __init__(
        self,
        kernel_size: int = 11,
        max_val: float = 1.0,
        sigma: float = 1.5,
        eps: float = 1e-12,
        use_padding: bool = True,
    ) -> None:
        """SSIM loss function.

        Args:
            kernel_size: size of the Gaussian kernel
            max_val: constant scaling factor
            sigma: sigma of the Gaussian kernel
            eps: constant for division by zero
            use_padding: whether to pad the input tensor such that we have a score for each pixel
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.max_val = max_val
        self.eps = eps
        self.use_padding = use_padding

        # stabilizing constants scale with the configured value range
        self.c1 = (0.01 * self.max_val) ** 2
        self.c2 = (0.03 * self.max_val) ** 2

        # registered as a buffer so the window moves with the module
        self.register_buffer("gaussian_kernel", make_gaussian_kernel(kernel_size, sigma))

    def forward(self, source: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """
        Args:
            source: source image of shape (batch_size, C, H, W)
            reference: reference image of shape (batch_size, C, H, W)

        Returns:
            SSIM loss of shape (batch_size,)
        """
        return _ssim_loss_fn(
            source,
            reference,
            self.gaussian_kernel,
            eps=self.eps,
            c1=self.c1,
            c2=self.c2,
            use_padding=self.use_padding,
        )
def _smoothness_loss_fn(img_gx: Tensor, img_gy: Tensor, val_gx: Tensor, val_gy: Tensor):
    """Weighted gradient smoothness term.

    The ``img_*`` gradients are scaled by ``exp(-mean(|val_*|))``, where the mean is
    taken over the channel dimension, and the absolute weighted gradients of both
    directions are averaged over the last three dimensions.

    Args:
        img_gx: x-gradients of shape (..., C, H, W)
        img_gy: y-gradients, combined element-wise with ``val_gy``
        val_gx: x-gradients with the same or a broadcastable shape as ``img_gx``
        val_gy: y-gradients used to weight ``img_gy``

    Returns:
        tensor holding one value per leading index
    """
    # ref: https://github.com/nianticlabs/monodepth2/blob/b676244e5a1ca55564eb5d16ab521a48f823af31/layers.py#L202
    torch._assert(
        img_gx.ndim >= 3,
        "smoothness_loss: `img_gx` must be at least 3-dimensional tensor of shape (..., C, H, W)",
    )
    torch._assert(
        img_gx.ndim == val_gx.ndim,
        "smoothness_loss: `img_gx` and `depth_gx` must have the same dimensionality, but got {} and {}".format(
            img_gx.ndim, val_gx.ndim
        ),
    )

    shape_error = "smoothness_loss: `img_gx` and `depth_gx` must have either the same shape or broadcastable shape, but got {} and {}".format(
        img_gx.shape, val_gx.shape
    )
    # every pair of sizes must either match or contain a 1
    for img_size, val_size in zip(img_gx.shape, val_gx.shape):
        torch._assert(img_size == val_size or 1 in (img_size, val_size), shape_error)

    def decay(gradient: Tensor) -> Tensor:
        # -3 is channel dimension
        return torch.exp(-gradient.abs().mean(dim=-3, keepdim=True))

    weighted_x = (img_gx * decay(val_gx)).abs()
    weighted_y = (img_gy * decay(val_gy)).abs()
    return (weighted_x + weighted_y).mean(dim=(-3, -2, -1))
class SmoothnessLoss(nn.Module):
    """Smoothness loss built from forward differences of ``images`` and ``vals``."""

    def __init__(self) -> None:
        super().__init__()

    def _forward_difference(self, tensor: Tensor, along_width: bool) -> Tensor:
        """Difference between each element and its next neighbor along one spatial axis.

        The tensor is replicate-padded by one element at the far end of the chosen
        axis, so the output keeps the input shape and the last difference is zero.
        """
        full_shape = tensor.shape
        collapse = tensor.ndim > 4
        if collapse:
            # fold all leading dimensions into one before padding
            tensor = tensor.reshape(-1, *full_shape[-3:])

        if along_width:
            axis, pad = -1, (0, 1, 0, 0)
        else:
            axis, pad = -2, (0, 0, 0, 1)

        padded = F.pad(tensor, pad, mode="replicate")
        count = padded.shape[axis] - 1
        difference = padded.narrow(axis, 0, count) - padded.narrow(axis, 1, count)

        return difference.reshape(full_shape) if collapse else difference

    def _x_gradient(self, img: Tensor) -> Tensor:
        """Forward difference along the last (width) dimension."""
        return self._forward_difference(img, along_width=True)

    def _y_gradient(self, x: torch.Tensor) -> torch.Tensor:
        """Forward difference along the second-to-last (height) dimension."""
        return self._forward_difference(x, along_width=False)

    def forward(self, images: Tensor, vals: Tensor) -> Tensor:
        """
        Args:
            images: tensor of shape (D1, D2, ..., DN, C, H, W)
            vals: tensor of shape (D1, D2, ..., DN, 1, H, W)

        Returns:
            smoothness loss of shape (D1, D2, ..., DN)
        """
        return _smoothness_loss_fn(
            self._x_gradient(images),
            self._y_gradient(images),
            self._x_gradient(vals),
            self._y_gradient(vals),
        )
def _flow_sequence_consistency_loss_fn(
    flow_preds: List[Tensor],
    gamma: float = 0.8,
    resize_factor: float = 0.25,
    rescale_factor: float = 0.25,
    rescale_mode: str = "bilinear",
    weights: Optional[Tensor] = None,
):
    """Loss function defined over sequence of flow predictions.

    Penalizes the squared difference between consecutive predictions, weighted so
    that the last pair gets weight 1 and earlier pairs decay by ``gamma``.

    Args:
        flow_preds: list of tensors of identical shape (batch_size, C, H, W)
        gamma: base of the exponential weighting across consecutive pairs
        resize_factor: spatial scale factor passed to ``F.interpolate``
        rescale_factor: multiplier applied to the resized flow values; must be at
            most 1, and a falsy value (0) skips the resizing step entirely
        rescale_mode: interpolation mode passed to ``F.interpolate``
        weights: previously computed per-pair weights; reused only when their
            length matches the number of consecutive pairs

    Returns:
        tuple of (scalar loss, per-pair weights that were used)
    """
    # Simplified version of ref: https://arxiv.org/pdf/2006.11242.pdf
    # In the original paper, an additional refinement network is used to refine a flow prediction.
    # Each step performed by the recurrent module in Raft or CREStereo is a refinement step using a delta_flow update.
    # which should be consistent with the previous step. In this implementation, we simplify the overall loss
    # term and ignore left-right consistency loss or photometric loss which can be treated separately.
    torch._assert(
        rescale_factor <= 1.0,
        "sequence_consistency_loss: `rescale_factor` must be less than or equal to 1, but got {}".format(
            rescale_factor
        ),
    )

    flow_preds = torch.stack(flow_preds)  # shape = (num_flow_updates, batch_size, 2, H, W)
    N, B, C, H, W = flow_preds.shape

    # rescale flow predictions to account for bilinear upsampling artifacts
    if rescale_factor:
        # `F.interpolate` rejects an explicit `align_corners` for the
        # non-interpolating modes (e.g. "nearest", "area"), so only set it for
        # the modes that accept it.
        interpolating_modes = ("linear", "bilinear", "bicubic", "trilinear")
        align_corners = True if rescale_mode in interpolating_modes else None
        resized = (
            F.interpolate(
                flow_preds.view(N * B, C, H, W),
                scale_factor=resize_factor,
                mode=rescale_mode,
                align_corners=align_corners,
            )
            * rescale_factor
        )
        # restore the leading (num_flow_updates, batch_size) dimensions
        flow_preds = resized.reshape(N, B, C, *resized.shape[-2:])

    # force the next prediction to be similar to the previous prediction
    squared_diff = (flow_preds[1:] - flow_preds[:-1]).square()
    per_pair_error = squared_diff.mean(dim=(1, 2, 3, 4))

    num_predictions = flow_preds.shape[0] - 1  # because we are comparing differences
    if weights is None or len(weights) != num_predictions:
        weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)

    flow_loss = (per_pair_error * weights).sum()
    return flow_loss, weights
class FlowSequenceConsistencyLoss(nn.Module):
    """Module wrapper around ``_flow_sequence_consistency_loss_fn`` that caches the per-pair weights."""

    def __init__(
        self,
        gamma: float = 0.8,
        resize_factor: float = 0.25,
        rescale_factor: float = 0.25,
        rescale_mode: str = "bilinear",
    ) -> None:
        """
        Args:
            gamma: value for the exponential weighting of the loss across consecutive predictions
            resize_factor: spatial scale factor applied to the predictions before comparison
            rescale_factor: multiplier applied to the resized flow values
            rescale_mode: interpolation mode used for the resizing
        """
        super().__init__()
        self.gamma = gamma
        self.resize_factor = resize_factor
        self.rescale_factor = rescale_factor
        self.rescale_mode = rescale_mode
        # weights returned by the last forward pass, reused on the next one
        self._weights = None

    def forward(self, flow_preds: List[Tensor]) -> Tensor:
        """
        Args:
            flow_preds: list of tensors of shape (batch_size, C, H, W)

        Returns:
            sequence consistency loss as a scalar tensor
        """
        loss, weights = _flow_sequence_consistency_loss_fn(
            flow_preds,
            gamma=self.gamma,
            resize_factor=self.resize_factor,
            rescale_factor=self.rescale_factor,
            rescale_mode=self.rescale_mode,
            weights=self._weights,
        )
        self._weights = weights
        return loss

    def set_gamma(self, gamma: float) -> None:
        """Replace the decay base and drop the cached weights."""
        # `gamma` is stored as a plain float here (not a tensor buffer), so it is
        # reassigned rather than filled in place.
        self.gamma = gamma
        # reset the cached scale factor
        self._weights = None
def _psnr_loss_fn(source: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
    """Peak signal-to-noise ratio in decibels, reduced over the last three dimensions.

    Args:
        source: tensor whose last three dimensions are averaged over
        target: tensor with the same shape as ``source``
        max_val: peak value used in the numerator of the ratio

    Returns:
        tensor holding one PSNR value per leading index
    """
    torch._assert(
        source.shape == target.shape,
        "psnr_loss: source and target must have the same shape, but got {} and {}".format(source.shape, target.shape),
    )
    # ref https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    mean_squared_error = (source - target).pow(2).mean(dim=(-3, -2, -1))
    peak_power = max_val**2
    return 10 * torch.log10(peak_power / mean_squared_error)
class PSNRLoss(nn.Module):
    """Negated PSNR, so that minimizing the loss maximizes the PSNR."""

    def __init__(self, max_val: float = 256) -> None:
        """
        Args:
            max_val: maximum value of the input tensor. This refers to the maximum domain value of the input tensor.
        """
        super().__init__()
        self.max_val = max_val

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """
        Args:
            source: tensor of shape (D1, D2, ..., DN, C, H, W)
            target: tensor of shape (D1, D2, ..., DN, C, H, W)

        Returns:
            psnr loss of shape (D1, D2, ..., DN)
        """
        psnr = _psnr_loss_fn(source, target, self.max_val)
        # flip the sign as we want to maximize the psnr
        return -1 * psnr
class FlowPhotoMetricLoss(nn.Module):
    """Photometric loss combining an L1 term and an SSIM term.

    ``reference`` is resampled with the predicted flow and compared against
    ``source`` on the pixels that pass the validity and displacement masks.
    """

    def __init__(
        self,
        ssim_weight: float = 0.85,
        ssim_window_size: int = 11,
        ssim_max_val: float = 1.0,
        ssim_sigma: float = 1.5,
        ssim_eps: float = 1e-12,
        ssim_use_padding: bool = True,
        max_displacement_ratio: float = 0.15,
    ) -> None:
        """
        Args:
            ssim_weight: weight of the SSIM term; the L1 term is weighted by ``1 - ssim_weight``
            ssim_window_size: kernel size forwarded to ``SSIM``
            ssim_max_val: ``max_val`` forwarded to ``SSIM``
            ssim_sigma: ``sigma`` forwarded to ``SSIM``
            ssim_eps: ``eps`` forwarded to ``SSIM``
            ssim_use_padding: ``use_padding`` forwarded to ``SSIM``
            max_displacement_ratio: fraction of the image width (flow channel 0) and
                height (flow channel 1) used as the upper bound on the flow values
        """
        super().__init__()

        self._ssim_loss = SSIM(
            kernel_size=ssim_window_size,
            max_val=ssim_max_val,
            sigma=ssim_sigma,
            eps=ssim_eps,
            use_padding=ssim_use_padding,
        )

        self._L1_weight = 1 - ssim_weight
        self._SSIM_weight = ssim_weight
        self._max_displacement_ratio = max_displacement_ratio

    def forward(
        self,
        source: Tensor,
        reference: Tensor,
        flow_pred: Tensor,
        valid_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            source: tensor of shape (B, C, H, W)
            reference: tensor of shape (B, C, H, W)
            flow_pred: tensor of shape (B, 2, H, W)
            valid_mask: tensor of shape (B, H, W) or None

        Returns:
            photometric loss as a scalar tensor (mean over the batch)
        """
        torch._assert(
            source.ndim == 4,
            "FlowPhotoMetricLoss: source must have 4 dimensions, but got {}".format(source.ndim),
        )
        torch._assert(
            reference.ndim == source.ndim,
            "FlowPhotoMetricLoss: source and other must have the same number of dimensions, but got {} and {}".format(
                source.ndim, reference.ndim
            ),
        )
        # Check the rank before indexing `shape[1]`, so that a low-rank input
        # fails with the intended message instead of an IndexError.
        torch._assert(
            flow_pred.ndim == 4,
            "FlowPhotoMetricLoss: flow_pred must have 4 dimensions, but got {}".format(flow_pred.ndim),
        )
        torch._assert(
            flow_pred.shape[1] == 2,
            "FlowPhotoMetricLoss: flow_pred must have 2 channels, but got {}".format(flow_pred.shape[1]),
        )

        B, _, H, W = source.shape
        flow_channels = flow_pred.shape[1]

        # channel 0 is bounded by a fraction of the width (last dimension),
        # channel 1 by a fraction of the height (second-to-last dimension)
        max_displacements = [
            int(self._max_displacement_ratio * source.shape[-1 - dim]) for dim in range(flow_channels)
        ]

        # mask out all pixels that have larger flow than the max flow allowed
        # NOTE(review): the comparison is signed, so large negative flows are not
        # masked out -- confirm whether the magnitude was intended.
        max_flow_mask = torch.logical_and(
            *[flow_pred[:, dim, :, :] < max_displacements[dim] for dim in range(flow_channels)]
        )

        # add a channel axis so the mask broadcasts against the images
        if valid_mask is not None:
            valid_mask = torch.logical_and(valid_mask, max_flow_mask).unsqueeze(1)
        else:
            valid_mask = max_flow_mask.unsqueeze(1)

        # sample `reference` at the base coordinates shifted by the predicted flow
        grid = make_coords_grid(B, H, W, device=str(source.device))
        resampled_grids = grid - flow_pred
        # move the coordinate channel last before sampling
        resampled_grids = resampled_grids.permute(0, 2, 3, 1)
        resampled_source = grid_sample(reference, resampled_grids, mode="bilinear")

        masked_resampled = resampled_source * valid_mask
        masked_source = source * valid_mask

        # compute SSIM loss
        ssim_loss = self._ssim_loss(masked_resampled, masked_source)
        l1_loss = (masked_resampled - masked_source).abs().mean(dim=(-3, -2, -1))
        loss = self._L1_weight * l1_loss + self._SSIM_weight * ssim_loss

        return loss.mean()
|