123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- from typing import Optional, Tuple, Union
- import torch
- import transforms as T
class StereoMatchingEvalPreset(torch.nn.Module):
    """Evaluation-time transform pipeline assembled from the ``transforms`` module.

    The steps are chained with ``T.Compose`` in this order: ``ToTensor``,
    ``ConvertImageDtype(torch.float32)``, optional ``ConvertToGrayscale``,
    optional ``Resize``, ``Normalize``, ``MakeValidDisparityMask`` and
    ``ValidateModelInput``.

    Args:
        mean: Passed to ``T.Normalize`` as ``mean``.
        std: Passed to ``T.Normalize`` as ``std``.
        resize_size: When not ``None``, a ``T.Resize`` step with this size is
            added to the pipeline.
        max_disparity: Passed to ``T.MakeValidDisparityMask`` as
            ``max_disparity``.
        interpolation_type: Passed to ``T.Resize`` as ``interpolation_type``;
            only used when ``resize_size`` is given.
        use_grayscale: When true, a ``T.ConvertToGrayscale`` step is added.
    """

    def __init__(
        self,
        mean: float = 0.5,
        std: float = 0.5,
        resize_size: Optional[Tuple[int, ...]] = None,
        max_disparity: Optional[float] = None,
        interpolation_type: str = "bilinear",
        use_grayscale: bool = False,
    ) -> None:
        super().__init__()

        # Mandatory head of the pipeline.
        pipeline = [T.ToTensor(), T.ConvertImageDtype(torch.float32)]

        # Optional steps, inserted before normalization.
        if use_grayscale:
            pipeline += [T.ConvertToGrayscale()]
        if resize_size is not None:
            pipeline += [T.Resize(resize_size, interpolation_type=interpolation_type)]

        # Mandatory tail of the pipeline.
        pipeline += [
            T.Normalize(mean=mean, std=std),
            T.MakeValidDisparityMask(max_disparity=max_disparity),
            T.ValidateModelInput(),
        ]

        self.transforms = T.Compose(pipeline)

    def forward(self, images, disparities, masks):
        """Run the composed pipeline on the inputs and return its result unchanged."""
        transformed = self.transforms(images, disparities, masks)
        return transformed
class StereoMatchingTrainPreset(torch.nn.Module):
    """Training-time augmentation pipeline assembled from the ``transforms`` module.

    All constructor arguments are keyword-only. The steps are chained with
    ``T.Compose`` in this order: ``ToTensor``, optional ``Resize``, optional
    ``ToGPU``, ``AsymmetricColorJitter``, ``AsymetricGammaAdjust``, optional
    ``ConvertToGrayscale``, ``RandomSpatialShift``,
    ``ConvertImageDtype(torch.float32)``, ``RandomRescaleAndCrop``,
    ``RandomHorizontalFlip``, ``RandomOcclusion``, ``RandomErase``,
    ``Normalize``, ``MakeValidDisparityMask`` and ``ValidateModelInput``.

    Args:
        resize_size: When not ``None``, a ``T.Resize`` step with this size is
            added right after ``ToTensor``.
        resize_interpolation_type: ``interpolation_type`` for that ``T.Resize``.
        crop_size: Passed to ``T.RandomRescaleAndCrop`` as ``crop_size``.
        rescale_prob: Passed to ``T.RandomRescaleAndCrop`` as ``rescale_prob``.
        scaling_type: Passed to ``T.RandomRescaleAndCrop``; must be
            ``"linear"`` or ``"exponential"``.
        scale_range: Passed to ``T.RandomRescaleAndCrop`` as ``scale_range``.
        scale_interpolation_type: ``interpolation_type`` for
            ``T.RandomRescaleAndCrop``.
        use_grayscale: When true, ``T.ConvertToGrayscale`` is added after the
            color augmentations.
        mean: Passed to ``T.Normalize`` as ``mean``.
        std: Passed to ``T.Normalize`` as ``std``.
        gpu_transforms: When true, a ``T.ToGPU`` step is added before the
            color augmentations.
        max_disparity: Passed positionally to ``T.MakeValidDisparityMask``.
        spatial_shift_prob: Passed to ``T.RandomSpatialShift`` as ``p``.
        spatial_shift_max_angle: Passed to ``T.RandomSpatialShift`` as
            ``max_angle``.
        spatial_shift_max_displacement: Passed to ``T.RandomSpatialShift`` as
            ``max_px_shift``.
        spatial_shift_interpolation_type: ``interpolation_type`` for
            ``T.RandomSpatialShift``.
        gamma_range: Passed to ``T.AsymetricGammaAdjust`` as ``gamma_range``.
        brightness: Passed to ``T.AsymmetricColorJitter``.
        contrast: Passed to ``T.AsymmetricColorJitter``.
        saturation: Passed to ``T.AsymmetricColorJitter``.
        hue: Passed to ``T.AsymmetricColorJitter``.
        asymmetric_jitter_prob: Passed as ``p`` to both
            ``T.AsymmetricColorJitter`` and ``T.AsymetricGammaAdjust``.
        horizontal_flip_prob: Passed positionally to ``T.RandomHorizontalFlip``.
        occlusion_prob: Passed to ``T.RandomOcclusion`` as ``p``.
        occlusion_px_range: Passed to ``T.RandomOcclusion`` as
            ``occlusion_px_range``.
        erase_prob: Passed to ``T.RandomErase`` as ``p``.
        erase_px_range: Passed to ``T.RandomErase`` as ``erase_px_range``.
        erase_num_repeats: Passed to ``T.RandomErase`` as ``max_erase``.

    Raises:
        ValueError: If ``scaling_type`` is neither ``"linear"`` nor
            ``"exponential"``.
    """

    def __init__(
        self,
        *,
        resize_size: Optional[Tuple[int, ...]],
        resize_interpolation_type: str = "bilinear",
        # RandomResizeAndCrop params
        crop_size: Tuple[int, int],
        rescale_prob: float = 1.0,
        scaling_type: str = "exponential",
        scale_range: Tuple[float, float] = (-0.2, 0.5),
        scale_interpolation_type: str = "bilinear",
        # convert to grayscale
        use_grayscale: bool = False,
        # normalization params
        mean: float = 0.5,
        std: float = 0.5,
        # processing device
        gpu_transforms: bool = False,
        # masking (annotated like the eval preset's ``max_disparity``)
        max_disparity: Optional[float] = 256,
        # SpatialShift params
        spatial_shift_prob: float = 0.5,
        spatial_shift_max_angle: float = 0.5,
        spatial_shift_max_displacement: float = 0.5,
        spatial_shift_interpolation_type: str = "bilinear",
        # AsymmetricColorJitter (defaults are floats, so the hints are float-based)
        gamma_range: Tuple[float, float] = (0.8, 1.2),
        brightness: Union[float, Tuple[float, float]] = (0.8, 1.2),
        contrast: Union[float, Tuple[float, float]] = (0.8, 1.2),
        saturation: Union[float, Tuple[float, float]] = 0.0,
        hue: Union[float, Tuple[float, float]] = 0.0,
        asymmetric_jitter_prob: float = 1.0,
        # RandomHorizontalFlip
        horizontal_flip_prob: float = 0.5,
        # RandomOcclusion
        occlusion_prob: float = 0.0,
        occlusion_px_range: Tuple[int, int] = (50, 100),
        # RandomErase
        erase_prob: float = 0.0,
        erase_px_range: Tuple[int, int] = (50, 100),
        erase_num_repeats: int = 1,
    ) -> None:
        # Reject unsupported scaling modes before any module state is created.
        if scaling_type not in ["linear", "exponential"]:
            raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")

        super().__init__()
        transforms = [T.ToTensor()]

        # when fixing size across multiple datasets, we ensure
        # that the same size is used for all datasets when cropping
        if resize_size is not None:
            transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))

        if gpu_transforms:
            transforms.append(T.ToGPU())

        # color handling
        color_transforms = [
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
        ]

        # Grayscale conversion, when requested, runs after the color augmentations.
        if use_grayscale:
            color_transforms.append(T.ConvertToGrayscale())

        transforms.extend(color_transforms)

        transforms.extend(
            [
                T.RandomSpatialShift(
                    p=spatial_shift_prob,
                    max_angle=spatial_shift_max_angle,
                    max_px_shift=spatial_shift_max_displacement,
                    interpolation_type=spatial_shift_interpolation_type,
                ),
                T.ConvertImageDtype(torch.float32),
                T.RandomRescaleAndCrop(
                    crop_size=crop_size,
                    scale_range=scale_range,
                    rescale_prob=rescale_prob,
                    scaling_type=scaling_type,
                    interpolation_type=scale_interpolation_type,
                ),
                T.RandomHorizontalFlip(horizontal_flip_prob),
                # occlusion after flip, otherwise we're occluding the reference image
                T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
                T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
                T.Normalize(mean=mean, std=std),
                T.MakeValidDisparityMask(max_disparity),
                T.ValidateModelInput(),
            ]
        )

        self.transforms = T.Compose(transforms)

    # NOTE: the parameter spelling ``disparties`` is kept as-is so that callers
    # passing it by keyword keep working.
    def forward(self, images, disparties, mask):
        """Run the composed pipeline on the inputs and return its result unchanged."""
        return self.transforms(images, disparties, mask)
|