# Dataset registry and transform factories for the stereo-matching reference scripts.
  1. import argparse
  2. from functools import partial
  3. import torch
  4. from presets import StereoMatchingEvalPreset, StereoMatchingTrainPreset
  5. from torchvision.datasets import (
  6. CarlaStereo,
  7. CREStereo,
  8. ETH3DStereo,
  9. FallingThingsStereo,
  10. InStereo2k,
  11. Kitti2012Stereo,
  12. Kitti2015Stereo,
  13. Middlebury2014Stereo,
  14. SceneFlowStereo,
  15. SintelStereo,
  16. )
  17. VALID_DATASETS = {
  18. "crestereo": partial(CREStereo),
  19. "carla-highres": partial(CarlaStereo),
  20. "instereo2k": partial(InStereo2k),
  21. "sintel": partial(SintelStereo),
  22. "sceneflow-monkaa": partial(SceneFlowStereo, variant="Monkaa", pass_name="both"),
  23. "sceneflow-flyingthings": partial(SceneFlowStereo, variant="FlyingThings3D", pass_name="both"),
  24. "sceneflow-driving": partial(SceneFlowStereo, variant="Driving", pass_name="both"),
  25. "fallingthings": partial(FallingThingsStereo, variant="both"),
  26. "eth3d-train": partial(ETH3DStereo, split="train"),
  27. "eth3d-test": partial(ETH3DStereo, split="test"),
  28. "kitti2015-train": partial(Kitti2015Stereo, split="train"),
  29. "kitti2015-test": partial(Kitti2015Stereo, split="test"),
  30. "kitti2012-train": partial(Kitti2012Stereo, split="train"),
  31. "kitti2012-test": partial(Kitti2012Stereo, split="train"),
  32. "middlebury2014-other": partial(
  33. Middlebury2014Stereo, split="additional", use_ambient_view=True, calibration="both"
  34. ),
  35. "middlebury2014-train": partial(Middlebury2014Stereo, split="train", calibration="perfect"),
  36. "middlebury2014-test": partial(Middlebury2014Stereo, split="test", calibration=None),
  37. "middlebury2014-train-ambient": partial(
  38. Middlebury2014Stereo, split="train", use_ambient_views=True, calibrartion="perfect"
  39. ),
  40. }
  41. def make_train_transform(args: argparse.Namespace) -> torch.nn.Module:
  42. return StereoMatchingTrainPreset(
  43. resize_size=args.resize_size,
  44. crop_size=args.crop_size,
  45. rescale_prob=args.rescale_prob,
  46. scaling_type=args.scaling_type,
  47. scale_range=args.scale_range,
  48. scale_interpolation_type=args.interpolation_strategy,
  49. use_grayscale=args.use_grayscale,
  50. mean=args.norm_mean,
  51. std=args.norm_std,
  52. horizontal_flip_prob=args.flip_prob,
  53. gpu_transforms=args.gpu_transforms,
  54. max_disparity=args.max_disparity,
  55. spatial_shift_prob=args.spatial_shift_prob,
  56. spatial_shift_max_angle=args.spatial_shift_max_angle,
  57. spatial_shift_max_displacement=args.spatial_shift_max_displacement,
  58. spatial_shift_interpolation_type=args.interpolation_strategy,
  59. gamma_range=args.gamma_range,
  60. brightness=args.brightness_range,
  61. contrast=args.contrast_range,
  62. saturation=args.saturation_range,
  63. hue=args.hue_range,
  64. asymmetric_jitter_prob=args.asymmetric_jitter_prob,
  65. )
  66. def make_eval_transform(args: argparse.Namespace) -> torch.nn.Module:
  67. if args.eval_size is None:
  68. resize_size = args.crop_size
  69. else:
  70. resize_size = args.eval_size
  71. return StereoMatchingEvalPreset(
  72. mean=args.norm_mean,
  73. std=args.norm_std,
  74. use_grayscale=args.use_grayscale,
  75. resize_size=resize_size,
  76. interpolation_type=args.interpolation_strategy,
  77. )
  78. def make_dataset(dataset_name: str, dataset_root: str, transforms: torch.nn.Module) -> torch.utils.data.Dataset:
  79. return VALID_DATASETS[dataset_name](root=dataset_root, transforms=transforms)