# presets.py — transform presets for stereo matching training and evaluation.
from typing import Optional, Tuple, Union

import torch

import transforms as T
  4. class StereoMatchingEvalPreset(torch.nn.Module):
  5. def __init__(
  6. self,
  7. mean: float = 0.5,
  8. std: float = 0.5,
  9. resize_size: Optional[Tuple[int, ...]] = None,
  10. max_disparity: Optional[float] = None,
  11. interpolation_type: str = "bilinear",
  12. use_grayscale: bool = False,
  13. ) -> None:
  14. super().__init__()
  15. transforms = [
  16. T.ToTensor(),
  17. T.ConvertImageDtype(torch.float32),
  18. ]
  19. if use_grayscale:
  20. transforms.append(T.ConvertToGrayscale())
  21. if resize_size is not None:
  22. transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))
  23. transforms.extend(
  24. [
  25. T.Normalize(mean=mean, std=std),
  26. T.MakeValidDisparityMask(max_disparity=max_disparity),
  27. T.ValidateModelInput(),
  28. ]
  29. )
  30. self.transforms = T.Compose(transforms)
  31. def forward(self, images, disparities, masks):
  32. return self.transforms(images, disparities, masks)
  33. class StereoMatchingTrainPreset(torch.nn.Module):
  34. def __init__(
  35. self,
  36. *,
  37. resize_size: Optional[Tuple[int, ...]],
  38. resize_interpolation_type: str = "bilinear",
  39. # RandomResizeAndCrop params
  40. crop_size: Tuple[int, int],
  41. rescale_prob: float = 1.0,
  42. scaling_type: str = "exponential",
  43. scale_range: Tuple[float, float] = (-0.2, 0.5),
  44. scale_interpolation_type: str = "bilinear",
  45. # convert to grayscale
  46. use_grayscale: bool = False,
  47. # normalization params
  48. mean: float = 0.5,
  49. std: float = 0.5,
  50. # processing device
  51. gpu_transforms: bool = False,
  52. # masking
  53. max_disparity: Optional[int] = 256,
  54. # SpatialShift params
  55. spatial_shift_prob: float = 0.5,
  56. spatial_shift_max_angle: float = 0.5,
  57. spatial_shift_max_displacement: float = 0.5,
  58. spatial_shift_interpolation_type: str = "bilinear",
  59. # AssymetricColorJitter
  60. gamma_range: Tuple[float, float] = (0.8, 1.2),
  61. brightness: Union[int, Tuple[int, int]] = (0.8, 1.2),
  62. contrast: Union[int, Tuple[int, int]] = (0.8, 1.2),
  63. saturation: Union[int, Tuple[int, int]] = 0.0,
  64. hue: Union[int, Tuple[int, int]] = 0.0,
  65. asymmetric_jitter_prob: float = 1.0,
  66. # RandomHorizontalFlip
  67. horizontal_flip_prob: float = 0.5,
  68. # RandomOcclusion
  69. occlusion_prob: float = 0.0,
  70. occlusion_px_range: Tuple[int, int] = (50, 100),
  71. # RandomErase
  72. erase_prob: float = 0.0,
  73. erase_px_range: Tuple[int, int] = (50, 100),
  74. erase_num_repeats: int = 1,
  75. ) -> None:
  76. if scaling_type not in ["linear", "exponential"]:
  77. raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")
  78. super().__init__()
  79. transforms = [T.ToTensor()]
  80. # when fixing size across multiple datasets, we ensure
  81. # that the same size is used for all datasets when cropping
  82. if resize_size is not None:
  83. transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))
  84. if gpu_transforms:
  85. transforms.append(T.ToGPU())
  86. # color handling
  87. color_transforms = [
  88. T.AsymmetricColorJitter(
  89. brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
  90. ),
  91. T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
  92. ]
  93. if use_grayscale:
  94. color_transforms.append(T.ConvertToGrayscale())
  95. transforms.extend(color_transforms)
  96. transforms.extend(
  97. [
  98. T.RandomSpatialShift(
  99. p=spatial_shift_prob,
  100. max_angle=spatial_shift_max_angle,
  101. max_px_shift=spatial_shift_max_displacement,
  102. interpolation_type=spatial_shift_interpolation_type,
  103. ),
  104. T.ConvertImageDtype(torch.float32),
  105. T.RandomRescaleAndCrop(
  106. crop_size=crop_size,
  107. scale_range=scale_range,
  108. rescale_prob=rescale_prob,
  109. scaling_type=scaling_type,
  110. interpolation_type=scale_interpolation_type,
  111. ),
  112. T.RandomHorizontalFlip(horizontal_flip_prob),
  113. # occlusion after flip, otherwise we're occluding the reference image
  114. T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
  115. T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
  116. T.Normalize(mean=mean, std=std),
  117. T.MakeValidDisparityMask(max_disparity),
  118. T.ValidateModelInput(),
  119. ]
  120. )
  121. self.transforms = T.Compose(transforms)
  122. def forward(self, images, disparties, mask):
  123. return self.transforms(images, disparties, mask)