123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650 |
- import random
- from typing import Callable, List, Optional, Sequence, Tuple, Union
- import numpy as np
- import PIL.Image
- import torch
- import torchvision.transforms as T
- import torchvision.transforms.functional as F
- from torch import Tensor
# Optional per-pixel disparity map: a torch tensor, a numpy array, or None when absent.
T_FLOW = Optional[Union[Tensor, np.ndarray]]
# Optional per-pixel validity mask, accepting the same container types as T_FLOW.
T_MASK = Optional[Union[Tensor, np.ndarray]]
# A (left, right) pair of image tensors.
T_STEREO_TENSOR = Tuple[Tensor, Tensor]
# Color-augmentation strength forwarded to torchvision's ColorJitter (a float or a pair of floats).
T_COLOR_AUG_PARAM = Union[float, Tuple[float, float]]
def rand_float_range(size: Sequence[int], low: float, high: float) -> Tensor:
    """Return a tensor of shape ``size`` filled with uniform random floats between ``low`` and ``high``.

    ``torch.rand`` samples from [0, 1). The affine map below sends 0 to ``high``
    and values approaching 1 towards ``low``. For ``low <= high`` the result
    therefore lies in the half-open interval (low, high].
    """
    unit = torch.rand(size)
    return unit * (low - high) + high
class InterpolationStrategy:
    """Randomly pick a torchvision interpolation mode from a configured set.

    ``"mixed"`` chooses between bilinear and bicubic on every call (via
    ``random.choice``). ``"bilinear"`` and ``"bicubic"`` always return that
    single mode.
    """

    _valid_modes: List[str] = ["mixed", "bicubic", "bilinear"]

    def __init__(self, mode: str = "mixed") -> None:
        """
        Args:
            mode: one of ``"mixed"``, ``"bicubic"`` or ``"bilinear"``.

        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        if not self.is_valid(mode):
            raise ValueError(f"Invalid interpolation mode: {mode}. Valid modes are: {self._valid_modes}")
        candidates = {
            "mixed": [F.InterpolationMode.BILINEAR, F.InterpolationMode.BICUBIC],
            "bicubic": [F.InterpolationMode.BICUBIC],
            "bilinear": [F.InterpolationMode.BILINEAR],
        }
        self.strategies = candidates[mode]

    def __call__(self) -> F.InterpolationMode:
        """Return one mode picked with ``random.choice`` from ``self.strategies``."""
        return random.choice(self.strategies)

    @classmethod
    def is_valid(cls, mode: str) -> bool:
        """Return True when ``mode`` is an accepted interpolation mode name."""
        # Takes ``cls`` explicitly so ``InterpolationStrategy.is_valid(mode)`` works;
        # without it the classmethod bound the class to ``mode`` and raised TypeError.
        return mode in cls._valid_modes

    @property
    def valid_modes(self) -> List[str]:
        """List of accepted mode names (a copy, so callers cannot mutate the class-wide list)."""
        return list(self._valid_modes)
class ValidateModelInput(torch.nn.Module):
    """Identity transform that sanity-checks shapes and dtypes before data reaches the model.

    Only the first (left) disparity and mask are inspected. The inputs are
    returned unchanged when every check passes.

    Raises:
        ValueError: when the two images differ in shape, or the left disparity /
            mask does not match the image's spatial size.
        TypeError: when the left mask is not boolean.
    """

    def forward(self, images: T_STEREO_TENSOR, disparities: T_FLOW, masks: T_MASK):
        left, right = images[0], images[1]
        if left.shape != right.shape:
            raise ValueError("img1 and img2 should have the same shape.")
        h, w = left.shape[-2:]

        disparity = disparities[0]
        if disparity is not None and disparity.shape != (1, h, w):
            raise ValueError(f"disparities[0].shape should be (1, {h}, {w}) instead of {disparity.shape}")

        mask = masks[0]
        if mask is None:
            return images, disparities, masks
        if mask.shape != (h, w):
            raise ValueError(f"masks[0].shape should be ({h}, {w}) instead of {mask.shape}")
        if mask.dtype != torch.bool:
            raise TypeError(f"masks[0] should be of dtype torch.bool instead of {mask.dtype}")
        return images, disparities, masks
class ConvertToGrayscale(torch.nn.Module):
    """Convert both stereo views to grayscale while keeping three output channels.

    Disparities and masks are returned untouched.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        images: Tuple[PIL.Image.Image, PIL.Image.Image],
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        left_gray, right_gray = (F.rgb_to_grayscale(images[view], num_output_channels=3) for view in (0, 1))
        return (left_gray, right_gray), disparities, masks
class MakeValidDisparityMask(torch.nn.Module):
    """Build boolean validity masks from the disparities.

    A missing (None) mask is first replaced by an all-True mask of the image's
    spatial size. For every view that has a disparity, a pixel stays valid only
    if its disparity is > 0 and, when ``max_disparity`` is not None, also
    < ``max_disparity``. Views without a disparity keep their (possibly newly
    created) mask as is.
    """

    def __init__(self, max_disparity: Optional[int] = 256) -> None:
        super().__init__()
        self.max_disparity = max_disparity

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        filled = [
            torch.ones(images[i].shape[-2:], dtype=torch.bool, device=images[i].device) if mask is None else mask
            for i, mask in enumerate(masks)
        ]

        refined = []
        for mask, disparity in zip(filled, disparities):
            if disparity is not None:
                # squeeze(0) drops the leading singleton dim introduced by broadcasting
                # the mask against the disparity map
                mask = torch.logical_and(mask, disparity > 0).squeeze(0)
                if self.max_disparity is not None:
                    mask = torch.logical_and(mask, disparity < self.max_disparity).squeeze(0)
            refined.append(mask)

        return images, disparities, tuple(refined)
class ToGPU(torch.nn.Module):
    """Move images, disparities and masks to CUDA with ``.cuda()``.

    ``None`` disparities / masks stay ``None``.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        def _maybe_cuda(item):
            return None if item is None else item.cuda()

        return (
            tuple(img.cuda() for img in images),
            tuple(_maybe_cuda(dsp) for dsp in disparities),
            tuple(_maybe_cuda(mask) for mask in masks),
        )
class ConvertImageDtype(torch.nn.Module):
    """Convert both stereo images to ``dtype`` via ``F.convert_image_dtype`` and make them contiguous.

    Disparities and masks pass through unchanged.
    """

    def __init__(self, dtype: torch.dtype):
        super().__init__()
        self.dtype = dtype

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        converted = [F.convert_image_dtype(images[view], dtype=self.dtype) for view in (0, 1)]
        left, right = (img.contiguous() for img in converted)
        return (left, right), disparities, masks
class Normalize(torch.nn.Module):
    """Normalize both stereo images with ``F.normalize(mean, std)`` and make them contiguous.

    Disparities and masks pass through unchanged.
    """

    def __init__(self, mean: List[float], std: List[float]) -> None:
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        normalized = [F.normalize(images[view], mean=self.mean, std=self.std) for view in (0, 1)]
        left, right = (img.contiguous() for img in normalized)
        return (left, right), disparities, masks
class ToTensor(torch.nn.Module):
    """Convert PIL stereo images with ``F.pil_to_tensor`` and numpy disparities/masks with ``torch.from_numpy``.

    ``None`` disparities / masks stay ``None``.

    Raises:
        ValueError: if either image is None.
    """

    def forward(
        self,
        images: Tuple[PIL.Image.Image, PIL.Image.Image],
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        if images[0] is None:
            raise ValueError("img_left is None")
        if images[1] is None:
            raise ValueError("img_right is None")

        left_tensor = F.pil_to_tensor(images[0])
        right_tensor = F.pil_to_tensor(images[1])

        def _as_tensor(array):
            return None if array is None else torch.from_numpy(array)

        # convert view by view (disparity, then mask), then regroup into two tuples
        per_view = [(_as_tensor(disparities[view]), _as_tensor(masks[view])) for view in range(2)]
        disparity_tensors, mask_tensors = zip(*per_view)
        return (left_tensor, right_tensor), disparity_tensors, mask_tensors
class AsymmetricColorJitter(T.ColorJitter):
    """Color jitter for stereo pairs that is asymmetric with probability ``p``.

    With probability ``p`` each view goes through its own ``ColorJitter`` call.
    Otherwise both views are stacked and jittered by a single call, so they
    receive the same transform.
    """

    def __init__(
        self,
        brightness: T_COLOR_AUG_PARAM = 0,
        contrast: T_COLOR_AUG_PARAM = 0,
        saturation: T_COLOR_AUG_PARAM = 0,
        hue: T_COLOR_AUG_PARAM = 0,
        p: float = 0.2,
    ):
        super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.p = p

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        jitter = super().forward
        if not (torch.rand(1) < self.p):
            # symmetric: one call over the stacked pair
            stacked = jitter(torch.stack(images))
            return (stacked[0], stacked[1]), disparities, masks
        # asymmetric: separate calls for the left and right views
        left = jitter(images[0])
        right = jitter(images[1])
        return (left, right), disparities, masks
class AsymetricGammaAdjust(torch.nn.Module):
    """Random gamma adjustment for stereo pairs, asymmetric with probability ``p``.

    A gamma value is drawn from ``gamma_range`` with ``rand_float_range``. With
    probability ``p`` the right view gets its own, independently drawn gamma
    (asymmetric). Otherwise both views share one gamma and are adjusted as a
    stacked batch (symmetric).

    Args:
        p: probability of using different gammas for the two views.
        gamma_range: (low, high) bounds of the sampled gamma.
        gain: multiplicative constant forwarded to ``F.adjust_gamma``.
    """

    def __init__(self, p: float, gamma_range: Tuple[float, float], gain: float = 1) -> None:
        super().__init__()
        self.gamma_range = gamma_range
        self.gain = gain
        self.p = p

    def _sample_gamma(self) -> float:
        """Draw one gamma value from ``self.gamma_range``."""
        return rand_float_range((1,), low=self.gamma_range[0], high=self.gamma_range[1]).item()

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        gamma = self._sample_gamma()

        if torch.rand(1) < self.p:
            # asymmetric: the right view needs its own gamma, otherwise this branch
            # would be identical to the symmetric one
            img_left = F.adjust_gamma(images[0], gamma, gain=self.gain)
            img_right = F.adjust_gamma(images[1], self._sample_gamma(), gain=self.gain)
        else:
            # symmetric: same transform for img1 and img2
            batch = F.adjust_gamma(torch.stack(images), gamma, gain=self.gain)
            img_left, img_right = batch[0], batch[1]

        return (img_left, img_right), disparities, masks
class RandomErase(torch.nn.Module):
    """With probability ``p``, erase the same random rectangles in both stereo views.

    The erasures can be viewed as occlusions present in both camera views.
    Similarly to optical-flow occlusion prediction, the erased pixels are set to
    False in every mask that is not None. Disparities are returned unchanged.

    Args:
        p: probability of applying the transform.
        erase_px_range: inclusive (min, max) side length, in pixels, of each
            rectangle (clamped to the image size).
        value: fill value written into the erased image region.
        inplace: forwarded to ``F.erase``.
        max_erase: maximum number of rectangles per call; the count is drawn
            uniformly from [1, max_erase].

    Raises:
        ValueError: on a negative or inverted ``erase_px_range`` or ``max_erase < 1``.
    """

    def __init__(
        self,
        p: float = 0.5,
        erase_px_range: Tuple[int, int] = (50, 100),
        value: Union[Tensor, float] = 0,
        inplace: bool = False,
        max_erase: int = 2,
    ):
        super().__init__()
        self.min_px_erase = erase_px_range[0]
        self.max_px_erase = erase_px_range[1]
        if self.max_px_erase < 0:
            raise ValueError("erase_px_range[1] should be equal or greater than 0")
        if self.min_px_erase < 0:
            raise ValueError("erase_px_range[0] should be equal or greater than 0")
        if self.min_px_erase > self.max_px_erase:
            raise ValueError("erase_px_range[0] should be equal or lower than erase_px_range[1]")
        if max_erase < 1:
            raise ValueError("max_erase should be equal or greater than 1")

        self.p = p
        self.value = value
        self.inplace = inplace
        self.max_erase = max_erase

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        # apply with probability p, like the other random transforms in this file
        if torch.rand(1) >= self.p:
            return images, disparities, masks

        image_left, image_right = images
        mask_left, mask_right = masks
        # torch.randint's upper bound is exclusive: this draws from [1, max_erase]
        num_erasures = int(torch.randint(1, self.max_erase + 1, size=(1,)).item())
        for _ in range(num_erasures):
            y, x, h, w, v = self._get_params(image_left)
            image_right = F.erase(image_right, y, x, h, w, v, self.inplace)
            image_left = F.erase(image_left, y, x, h, w, v, self.inplace)
            # similarly to optical flow occlusion prediction, we consider
            # any erasure pixels that are in both images to be occluded therefore
            # we mark them as invalid
            if mask_left is not None:
                mask_left = F.erase(mask_left, y, x, h, w, False, self.inplace)
            if mask_right is not None:
                mask_right = F.erase(mask_right, y, x, h, w, False, self.inplace)

        return (image_left, image_right), disparities, (mask_left, mask_right)

    def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, Union[Tensor, float]]:
        """Sample ``(top, left, height, width, fill value)`` for one erasure on ``img``.

        Side lengths are clamped to the image size so the rectangle always fits.
        """
        img_h, img_w = img.shape[-2:]
        crop_h = min(random.randint(self.min_px_erase, self.max_px_erase), img_h)
        crop_w = min(random.randint(self.min_px_erase, self.max_px_erase), img_w)
        crop_x = random.randint(0, img_w - crop_w)
        crop_y = random.randint(0, img_h - crop_h)
        return crop_y, crop_x, crop_h, crop_w, self.value
class RandomOcclusion(torch.nn.Module):
    """With probability ``p``, occlude a random rectangle of the right view only.

    The occluded patch works as a patch erase whose fill value is the mean of
    the covered pixels, taken over the last two dims. The left view,
    disparities and masks are returned unchanged.

    Args:
        p: probability of applying the transform.
        occlusion_px_range: inclusive (min, max) side length, in pixels, of the
            patch (clamped to the image size).
        inplace: forwarded to ``F.erase``.

    Raises:
        ValueError: on a negative or inverted ``occlusion_px_range``.
    """

    def __init__(self, p: float = 0.5, occlusion_px_range: Tuple[int, int] = (50, 100), inplace: bool = False):
        super().__init__()
        self.min_px_occlusion = occlusion_px_range[0]
        self.max_px_occlusion = occlusion_px_range[1]
        if self.max_px_occlusion < 0:
            raise ValueError("occlusion_px_range[1] should be greater or equal than 0")
        if self.min_px_occlusion < 0:
            raise ValueError("occlusion_px_range[0] should be greater or equal than 0")
        if self.min_px_occlusion > self.max_px_occlusion:
            raise ValueError("occlusion_px_range[0] should be lower or equal to occlusion_px_range[1]")

        self.p = p
        self.inplace = inplace

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        # apply with probability p, like the other random transforms in this file
        if torch.rand(1) >= self.p:
            return images, disparities, masks

        left_image, right_image = images
        y, x, h, w, v = self._get_params(right_image)
        right_image = F.erase(right_image, y, x, h, w, v, self.inplace)
        return ((left_image, right_image), disparities, masks)

    def _get_params(self, img: torch.Tensor) -> Tuple[int, int, int, int, Tensor]:
        """Sample ``(top, left, height, width, fill)`` where ``fill`` is the patch mean over the last two dims."""
        img_h, img_w = img.shape[-2:]
        crop_h = min(random.randint(self.min_px_occlusion, self.max_px_occlusion), img_h)
        crop_w = min(random.randint(self.min_px_occlusion, self.max_px_occlusion), img_w)
        crop_x = random.randint(0, img_w - crop_w)
        crop_y = random.randint(0, img_h - crop_h)
        occlusion_value = img[..., crop_y : crop_y + crop_h, crop_x : crop_x + crop_w].mean(dim=(-2, -1), keepdim=True)
        return (crop_y, crop_x, crop_h, crop_w, occlusion_value)
class RandomSpatialShift(torch.nn.Module):
    """With probability ``p``, apply a small rotation plus a vertical shift to the right view only.

    Warping only the right view mimics slight calibration issues between the
    cameras. The left view, disparities and masks are returned untouched.

    Args:
        p: probability of applying the transform.
        max_angle: the rotation angle passed to ``F.affine`` is drawn between
            -max_angle and max_angle.
        max_px_shift: the vertical translation is drawn between -max_px_shift
            and max_px_shift.
        interpolation_type: mode name handed to ``InterpolationStrategy``.
    """

    def __init__(
        self, p: float = 0.5, max_angle: float = 0.1, max_px_shift: int = 2, interpolation_type: str = "bilinear"
    ) -> None:
        super().__init__()
        self.p = p
        self.max_angle = max_angle
        self.max_px_shift = max_px_shift
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: T_STEREO_TENSOR,
        masks: T_STEREO_TENSOR,
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        left_view, right_view = images
        interpolation = self._interpolation_mode_strategy()

        if not (torch.rand(1) < self.p):
            return ((left_view, right_view), disparities, masks)

        # map [0, 1) samples onto the symmetric [-a, a] ranges
        dy = rand_float_range((1,), low=-self.max_px_shift, high=self.max_px_shift).item()
        theta = rand_float_range((1,), low=-self.max_angle, high=self.max_angle).item()
        # the rotation is centred on a randomly chosen pixel of the right view
        cy = torch.randint(size=(1,), low=0, high=right_view.shape[-2]).item()
        cx = torch.randint(size=(1,), low=0, high=right_view.shape[-1]).item()

        right_view = F.affine(
            right_view,
            angle=theta,
            translate=[0, dy],  # vertical translation only
            center=[cx, cy],
            scale=1.0,
            shear=0.0,
            interpolation=interpolation,
        )
        return ((left_view, right_view), disparities, masks)
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip a stereo sample with probability ``p``, swapping left and right.

    After mirroring, the left view plays the role of the right view and vice
    versa. Images, disparities and masks are therefore each flipped and then
    swapped. The flip is only considered when a right disparity is available,
    since it becomes the new left disparity.
    """

    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left, img_right = images
        dsp_left, dsp_right = disparities
        mask_left, mask_right = masks

        if dsp_right is None or not (torch.rand(1) < self.p):
            return images, disparities, masks

        def _hflip_optional(x):
            return None if x is None else F.hflip(x)

        img_left, img_right = F.hflip(img_left), F.hflip(img_right)
        # the left disparity may be absent even when the right one exists
        dsp_left, dsp_right = _hflip_optional(dsp_left), F.hflip(dsp_right)
        # flip each mask on its own so a lone mask stays aligned with its flipped image
        mask_left, mask_right = _hflip_optional(mask_left), _hflip_optional(mask_right)
        return ((img_right, img_left), (dsp_right, dsp_left), (mask_right, mask_left))
class Resize(torch.nn.Module):
    """Resize both views, their disparities and their masks to ``resize_size``.

    One interpolation mode is drawn per call and shared by images and
    disparities. Disparity values are multiplied by the width ratio
    (``resize_size[1]`` over the input width) so they match the new image size.
    Masks use nearest-neighbour interpolation. ``None`` entries stay ``None``.
    """

    def __init__(self, resize_size: Tuple[int, ...], interpolation_type: str = "bilinear") -> None:
        super().__init__()
        self.resize_size = list(resize_size)  # doing this to keep mypy happy
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        mode = self._interpolation_mode_strategy()
        size = self.resize_size

        def _resize_image(img):
            # We hard-code antialias=False to preserve results after we changed
            # its default from None to True (see
            # https://github.com/pytorch/vision/pull/7160)
            # TODO: we could re-train the stereo models with antialias=True?
            return F.resize(img, size, interpolation=mode, antialias=False)

        def _resize_disparity(dsp):
            if dsp is None:
                return None
            width_ratio = size[1] / dsp.shape[-1]
            return F.resize(dsp, size, interpolation=mode) * width_ratio

        def _resize_mask(mask):
            if mask is None:
                return None
            # add a leading dim for F.resize, then drop it again
            return F.resize(mask.unsqueeze(0), size, interpolation=F.InterpolationMode.NEAREST).squeeze(0)

        return (
            tuple(_resize_image(img) for img in images),
            tuple(_resize_disparity(dsp) for dsp in disparities),
            tuple(_resize_mask(mask) for mask in masks),
        )
class RandomRescaleAndCrop(torch.nn.Module):
    """Rescale a stereo sample with probability ``rescale_prob``, then take a random crop.

    These are the reversed operations of the built-in RandomResizedCrop,
    although the order of the operations doesn't matter too much: resizing a
    crop would give the same result as cropping a resized image, up to
    interpolation artifact at the borders of the output.

    The reason we don't rely on RandomResizedCrop is because of a significant
    difference in the parametrization of both transforms, in particular,
    because of the way the random parameters are sampled in both transforms,
    which leads to fairly different results (and different epe). For more details see
    https://github.com/pytorch/vision/pull/5026/files#r762932579

    Args:
        crop_size: (height, width) of the output crop.
        scale_range: (min, max) bounds of the sampled value ``s``. The applied
            factor is ``2 ** s`` for ``"exponential"`` scaling and ``s`` for
            ``"linear"``. It is then raised if needed so the rescaled image is at
            least 8 px larger than the crop.
        rescale_prob: probability of rescaling before cropping.
        scaling_type: ``"exponential"`` or ``"linear"``.
        interpolation_type: mode name handed to ``InterpolationStrategy``.

    Raises:
        ValueError: for an unknown ``scaling_type``, or a negative minimum
            scale with linear scaling.
    """

    def __init__(
        self,
        crop_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (-0.2, 0.5),
        rescale_prob: float = 0.8,
        scaling_type: str = "exponential",
        interpolation_type: str = "bilinear",
    ) -> None:
        super().__init__()
        self.crop_size = crop_size
        self.min_scale = scale_range[0]
        self.max_scale = scale_range[1]
        self.rescale_prob = rescale_prob
        self.scaling_type = scaling_type
        self._interpolation_mode_strategy = InterpolationStrategy(interpolation_type)

        # without this check an unknown type left ``scale`` unbound inside forward()
        if self.scaling_type not in ("exponential", "linear"):
            raise ValueError(f"scaling_type must be 'exponential' or 'linear', got {self.scaling_type!r}")
        if self.scaling_type == "linear" and self.min_scale < 0:
            raise ValueError("min_scale must be >= 0 for linear scaling")

    def forward(
        self,
        images: T_STEREO_TENSOR,
        disparities: Tuple[T_FLOW, T_FLOW],
        masks: Tuple[T_MASK, T_MASK],
    ) -> Tuple[T_STEREO_TENSOR, Tuple[T_FLOW, T_FLOW], Tuple[T_MASK, T_MASK]]:
        img_left, img_right = images
        interpolation = self._interpolation_mode_strategy()

        h, w = img_left.shape[-2:]
        # Note: in original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
        # It shouldn't matter much
        min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
        scale = max(self._sample_scale(), min_scale)
        new_size = (round(h * scale), round(w * scale))

        if torch.rand(1).item() < self.rescale_prob:
            img_left = F.resize(img_left, size=new_size, interpolation=interpolation)
            img_right = F.resize(img_right, size=new_size, interpolation=interpolation)
            # one (disparity, mask) pair per view, so None entries keep their position
            rescaled = [
                self._rescale_target(disparity, mask, scale, new_size, interpolation)
                for disparity, mask in zip(disparities, masks)
            ]
            disparities = tuple(disparity for disparity, _ in rescaled)
            masks = tuple(mask for _, mask in rescaled)

        # Note: For sparse datasets (Kitti), the original code uses a "margin"
        # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
        # We don't, not sure if it matters much
        crop_h, crop_w = self.crop_size
        y0 = self._random_offset(img_left.shape[-2], crop_h)
        x0 = self._random_offset(img_left.shape[-1], crop_w)

        def _crop(x):
            return None if x is None else F.crop(x, y0, x0, crop_h, crop_w)

        return (
            (_crop(img_left), _crop(img_right)),
            tuple(_crop(disparity) for disparity in disparities),
            tuple(_crop(mask) for mask in masks),
        )

    def _sample_scale(self) -> float:
        """Draw the (pre-clamping) scale factor according to ``self.scaling_type``."""
        # exponential scaling draws s in (min_scale, max_scale) and returns 2 ** s, so a
        # value of 1 doubles the size and -1 halves it; its mean and variance differ from
        # a uniform draw. Linear scaling returns s directly.
        sample = torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
        if self.scaling_type == "exponential":
            return 2**sample
        return sample

    @staticmethod
    def _rescale_target(disparity, mask, scale: float, size: Tuple[int, int], interpolation):
        """Rescale one view's disparity and mask to ``size``; returns ``(disparity, mask)``.

        A disparity with a mask is treated as sparse (``_resize_sparse_flow``).
        A disparity without a mask is interpolated and multiplied by ``scale``.
        A mask without a disparity is resized with nearest-neighbour so it stays
        aligned with the rescaled images.
        """
        if disparity is None:
            if mask is None:
                return None, None
            # add a leading dim for F.resize, then drop it again
            resized_mask = F.resize(mask.unsqueeze(0), size=size, interpolation=F.InterpolationMode.NEAREST)
            return None, resized_mask.squeeze(0)
        if mask is not None:
            return _resize_sparse_flow(disparity, mask, scale_x=scale, scale_y=scale)
        resized = F.resize(disparity, size=size, interpolation=interpolation)
        # rescale the disparity values together with the image
        return resized * scale, None

    @staticmethod
    def _random_offset(size: int, crop: int) -> int:
        """Sample a crop start in [0, size - crop] inclusive; 0 when the input is not larger than the crop."""
        # torch.randint's upper bound is exclusive, hence the + 1
        return int(torch.randint(0, max(size - crop, 0) + 1, size=(1,)).item())
- def _resize_sparse_flow(
- flow: Tensor, valid_flow_mask: Tensor, scale_x: float = 1.0, scale_y: float = 0.0
- ) -> Tuple[Tensor, Tensor]:
- # This resizes both the flow and the valid_flow_mask mask (which is assumed to be reasonably sparse)
- # There are as-many non-zero values in the original flow as in the resized flow (up to OOB)
- # So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4
- h, w = flow.shape[-2:]
- h_new = int(round(h * scale_y))
- w_new = int(round(w * scale_x))
- flow_new = torch.zeros(size=[1, h_new, w_new], dtype=flow.dtype)
- valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)
- jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")
- ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]
- ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
- jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)
- within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)
- ii_valid = ii_valid[within_bounds_mask]
- jj_valid = jj_valid[within_bounds_mask]
- ii_valid_new = ii_valid_new[within_bounds_mask]
- jj_valid_new = jj_valid_new[within_bounds_mask]
- valid_flow_new = flow[:, ii_valid, jj_valid]
- valid_flow_new *= scale_x
- flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
- valid_new[ii_valid_new, jj_valid_new] = valid_flow_mask[ii_valid, jj_valid]
- return flow_new, valid_new.bool()
class Compose(torch.nn.Module):
    """Chain stereo transforms, each mapping ``(images, disparities, masks)`` to a new triple.

    ``forward`` runs under ``torch.inference_mode()``, so no autograd history
    is recorded while the transforms execute.
    """

    def __init__(self, transforms: List[Callable]):
        super().__init__()
        self.transforms = transforms

    @torch.inference_mode()
    def forward(self, images, disparities, masks):
        state = (images, disparities, masks)
        for transform in self.transforms:
            # unpack so a transform that does not return exactly three items fails here
            imgs, dsps, msks = transform(*state)
            state = (imgs, dsps, msks)
        return state
|